From 11601b1477ff2957c1516ab5cd89d49f82c97a09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillaume=20Mar=C3=A7ais?= Date: Tue, 16 Jul 2024 11:54:27 -0400 Subject: [PATCH] Support for 128 bit base type for mers. * Some program don't run at this size. --- Tupfile | 3 +- common.hpp | 10 ++ configure.sh | 11 ++ mer_op.hpp | 181 ++++++++++++++++++- misc.hpp | 9 +- mt_queue.hpp | 121 +++++++++++++ old_champarnaud_set.cc | 61 ++++--- opt_canon.cc | 392 +++++++++++++++++++++++++---------------- opt_canon.yaggo | 23 ++- simple_thread_pool.hpp | 63 +++++++ sketch_components.cc | 94 +++++----- 11 files changed, 732 insertions(+), 236 deletions(-) create mode 100644 mt_queue.hpp create mode 100644 simple_thread_pool.hpp diff --git a/Tupfile b/Tupfile index 3bd1ab5..7f4b3c1 100644 --- a/Tupfile +++ b/Tupfile @@ -3,7 +3,8 @@ include_rules export PKG_CONFIG_PATH # CXXFLAGS=-O3 -DNDEBUG -Wall -Werror -DHAVE_EXECINFO_H `pkg-config --cflags libxxhash` -pthread -CXXFLAGS=-Wall -Werror -DHAVE_EXECINFO_H -I$(TUP_VARIANTDIR) -pthread -std=c++20 -DHAVE_INT128 +# CXXFLAGS=-Wall -Werror -DHAVE_EXECINFO_H -I$(TUP_VARIANTDIR) -pthread -std=c++20 -DHAVE_INT128 +CXXFLAGS=-Wall -DHAVE_EXECINFO_H -I$(TUP_VARIANTDIR) -pthread -std=gnu++20 -DHAVE_INT128 LDFLAGS=-pthread LDLIBS= diff --git a/common.hpp b/common.hpp index c508822..978ea68 100644 --- a/common.hpp +++ b/common.hpp @@ -7,6 +7,7 @@ #include #include #include +#include // State of a mer. Either it is nil (unknown), no (absent), yes (present) or // blocked (should not take part in an F-move). @@ -118,4 +119,13 @@ std::ostream& operator<<(std::ostream& os, const std::pair& p) { return os << p.first << ':' << p.second; } +template +std::ostream& operator<<(std::ostream& os, const std::optional& x) { + if(x) + os << "Some(" << *x << ')'; + else + os << "None"; + return os; +} + #endif // COMMON_H_ diff --git a/configure.sh b/configure.sh index 5f54ebd..1e04e0c 100755 --- a/configure.sh +++ b/configure.sh @@ -15,6 +15,7 @@ ALPHA (the alphabet size) and K (the k-mer length) are required. Extra compilation flags or option can be passed in the following environment variables: + CXX Path to g++ version at least 12 CXXFLAGS Compilation flags LDFLAGS Linker flags LDLIBS Extra libraries flags @@ -90,12 +91,22 @@ done [ -z "$ILPPYTHON" ] && echo >&2 "No ILP: didn't find a satisfying Python interpreter and packages" fi +# Find a valid version for g++ +GCXX= +for gcc in $CXX g++ g++-12; do + $gcc -o check_gcc_version -O0 check_gcc_version.cc + ./check_gcc_version 12 0 && GCXX=$gcc && break +done +rm -f check_gcc_version +[ -z "$GCXX" ] && { echo >&2 "Didn't find g++ version at least 12.0"; false; } + mkdir -p configs confFile=configs/${NAME}.config tmpFile=${confFile}.tmp cat > "$tmpFile" < +// #include #include #include #include #include #include +#include +#include + +// Print 128 bit long integers. Not very fast. Ignore formatting +std::ostream& operator<<(std::ostream& os, __uint128_t x) { + static constexpr int buflen = 40; + char buf[buflen]; + char* ptr = &buf[buflen - 1]; + *ptr = '\0'; + + do { + --ptr; + *ptr = ((char)(x % 10 + '0')); + x /= 10; + } while(x > 0); + return os << ptr; +} // Number of bits to encode a^k constexpr unsigned int log2ak(unsigned int a, unsigned k) { @@ -44,7 +61,8 @@ struct optimal_int { }; template -constexpr typename std::enable_if::value, T>::type +// constexpr typename std::enable_if::value, T>::type +constexpr T ipow(T base, unsigned int exp) { T result = 1; for(;;) { @@ -65,10 +83,64 @@ constexpr size_t nb_necklaces(unsigned a, unsigned k) { return res / k; } +// Bit twiddling operation to reverse bits in a word. Used for reverse +// complementation when the alphabet is a power of 2 (alpha == 2 or 4). Only +// defined on unsigned word types. +// +// Checkered mask. cmask is every other bit on (0x55). +// cmask is two bits one, two bits off (0x33). Etc. +template +struct cmask { + static constexpr + std::enable_if::value && std::is_integral::value, U>::type + v = (cmask::v << (2 * len)) | (((U)1 << len) - 1); +}; + +// When len is half of the word size, shifting by (2 * len) is undefined +// behavior. Fix it here. +template +struct cmask { + static constexpr + std::enable_if::value && std::is_integral::value, U>::type + v = (((U)1 << (sizeof(U) * 4)) - 1); +}; + +// Base case, when l = 0, start with empty (0) word. +template +struct cmask { + static constexpr + std::enable_if::value && std::is_integral::value, U>::type + v = 0; +}; + +template +inline +std::enable_if::value && std::is_integral::value, U>::type +word_reverse_complement(U w) { + if constexpr (alpha == 2) + w = ((w >> 1) & cmask::v) | ((w & cmask::v) << 1); + w = ((w >> 2) & cmask::v) | ((w & cmask::v) << 2); + w = ((w >> 4) & cmask::v) | ((w & cmask::v) << 4); + if constexpr (sizeof(U) >= 2) + w = ((w >> 8) & cmask::v) | ((w & cmask::v) << 8); + if constexpr (sizeof(U) >= 4) + w = ((w >> 16) & cmask::v) | ((w & cmask::v) << 16); + if constexpr (sizeof(U) >= 8) + w = ((w >> 32) & cmask::v) | ((w & cmask::v) << 32); + if constexpr (sizeof(U) >= 16) + w = ((w >> 64) & cmask::v) | ((w & cmask::v) << 64); + // return ~w; + return ((U)-1) - w; +} + template struct mer_op_type { // typedef mer_type mer_t; - typedef typename optimal_int::type mer_t; + constexpr static unsigned int ak_bits = log2ak(alpha_, k_); + typedef typename optimal_int::type mer_t; + + // Skip many program if encoding k takes too many bits + constexpr static unsigned int max_bits = 34; constexpr static unsigned int k = k_; constexpr static unsigned int alpha = alpha_; @@ -161,13 +233,20 @@ struct mer_op_type { return w; } - static mer_t reverse_comp(const mer_t m) { - mer_t res = 0; - mer_t left = m; - for(unsigned int i = 0; i < k; ++i, left /= alpha) { - res = (res * alpha) + (alpha - 1 - (left % alpha)); + inline static mer_t reverse_comp(const mer_t m) { + // Optimize for binary and DNA alphabet with bit twiddling + if constexpr (alpha == 2 || alpha == 4) { + const mer_t wc = word_reverse_complement(m); + constexpr unsigned shift = (8 * sizeof(mer_t) - (alpha/2)*k); + return wc >> shift; + } else { + // For general alphabets, looping algorithm + mer_t res = 0; + mer_t left = m; + for(unsigned int i = 0; i < k; ++i, left /= alpha) + res = (res * alpha) + (alpha - 1 - (left % alpha)); + return res; } - return res; } static mer_t canonical(const mer_t m) { @@ -175,4 +254,88 @@ struct mer_op_type { } }; + +template +struct amer_type { + typedef mer_op_type mer_ops; + typedef mer_ops::mer_t mer_t; + mer_t val; + + amer_type() = default; + amer_type(mer_t x) : val(x) {} + amer_type(const amer_type& rhs) : val(rhs.val) {} + static amer_type homopolymer(const mer_t base) { return mer_ops::homopolymer(base); } + + inline amer_type lb() const { return mer_ops::lb(val); } + inline amer_type rb() const { return mer_ops::rb(val); } + inline amer_type nmer() const { return mer_ops::nmer(val); } + inline amer_type nmer(const mer_t base) const { return mer_ops::nmer(val, base); } + inline amer_type pmer() const { return mer_ops::pmer(val); } + inline amer_type pmer(const mer_t base) const { return mer_ops::pmer(val, base); } + inline amer_type fmove() const { return mer_ops::fmove(val); } + inline amer_type rfmove() const { return mer_ops::rfmove(val); } + inline bool are_lc(const amer_type& rhs) const { return mer_ops::are_lc(val, rhs.val); } + inline bool ar_rc(const amer_type& rhs) const { return mer_ops::are_rc(val, rhs.val); } + inline amer_type lc(const mer_t base) const { return mer_ops::lc(val, base); } + inline amer_type rc(const mer_t base) const { return mer_ops::rc(val, base); } + inline bool is_homopolymer() const { return mer_ops::is_homopolymer(val); } + inline bool is_homopolymer_fm() const { return mer_ops::is_homopolymer_fm(val); } + inline mer_t weight() const { return mer_ops::weight(val); } + inline amer_type reverse_comp() const { return mer_ops::reverse_comp(val); } + inline amer_type canonical() const { return mer_ops::canonical(val); } + + struct mer_rc_pair { + amer_type mer, rc; + mer_rc_pair(const amer_type& m) + : mer(m) + , rc(m.reverse_comp()) + {} + mer_rc_pair& operator++() { + ++mer.val; + rc.val -= mer_ops::nb_fmoves; + return *this; + } + mer_rc_pair& operator--() { + --mer; + rc += mer_ops::nb_fmoves; + return *this; + } + }; + +}; + +template +std::ostream& operator<<(std::ostream& os, const amer_type& m) { + return os << (uint64_t)m.val; +} + +template +inline bool operator==(const amer_type& x, const amer_type& y) { + return x.val == y.val; +} + +template +inline bool operator!=(const amer_type& x, const amer_type& y) { + return x.val != y.val; +} + +template +inline bool operator<(const amer_type& x, const amer_type& y) { + return x.val < y.val; +} + + +template +struct std::hash> { + typedef amer_type amer_t; + std::size_t operator()(const amer_t& x) const noexcept { + return std::hash{}(x.val); + } +}; + +template +inline std::ostream& operator<<(std::ostream& os, const typename amer_type::mer_rc_pair& pair) { + return os << '(' << pair.mer << ',' << pair.rc << ')'; +} + #endif // MER_OP_H diff --git a/misc.hpp b/misc.hpp index 86b8fb3..dbc0ba1 100644 --- a/misc.hpp +++ b/misc.hpp @@ -74,12 +74,17 @@ std::vector mds_from_file(const char* path_mds, bool sort = true) { } template -C get_mds(const char* path_mds, std::vector args) { - C res; +void get_mds(const char* path_mds, std::vector args, C& res) { if(path_mds != nullptr && path_mds[0] != '\0') mds_read_to(path_mds, res); if(args.size() > 0) mds_parse_to(args, res); +} + +template +inline C get_mds(const char* path_mds, std::vector args) { + C res; + get_mds(path_mds, args, res); return res; } diff --git a/mt_queue.hpp b/mt_queue.hpp new file mode 100644 index 0000000..92868f2 --- /dev/null +++ b/mt_queue.hpp @@ -0,0 +1,121 @@ +#ifndef MT_QUEUE_H_ +#define MT_QUEUE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// /* base-2 logarithm, rounding down */ +// static inline constexpr uint64_t lg_down(uint64_t x) { +// return 63U - __builtin_clzl(x); +// } + +// /* base-2 logarithm, rounding up */ +// static inline constexpr uint64_t lg_up(uint64_t x) { +// return lg_down(x - 1) + 1; +// } + +// A multi-threaded queue for BFS. Two queues really for the current frontier +// and the next frontier. Pop pulls for the current, push appends to the next. +// Pull returns an empty std::optional element when empty. Use swap() to move to +// the next level. +// +// Objects are not deleted from the queue. Not sure that there is a trait that +// says that's OK. +template +class mt_queue { +public: + typedef T value_type; + typedef value_type& reference; + typedef const value_type& const_reference; + typedef value_type* pointer; + typedef const value_type* const_pointer; + typedef ssize_t difference_type; + typedef size_t size_type; + +protected: + std::vector _current; + size_type _current_size; // Number of slots filled + std::atomic _index; // Location to pop + + std::vector _next;// all pointers in _next. + std::atomic _next_size; // Number of slots filled + // + + const size_t chunk; + +public: + mt_queue(size_type size) + : _current(size, 0) + , _current_size(0) + , _index(0) + , _next(size, 0) + , _next_size(0) + , chunk(std::min(getpagesize() / sizeof(T), size)) + { // std::cout << "chunk " << chunk << ' ' << sizeof(T) << ' ' << size << std::endl; + } + + // Is current frontier empty? + inline bool current_empty() const { return _current_size == 0; } + + // Swap: make the next frontier be the current. Should be call after getting + // an empty element from pop(), to start eploring the next level + void swap(bool copy_next = true) { + // std::cout << "swap " << _current_size << ' ' << _next_size << std::endl; + std::swap(_current, _next); + _current_size = _next_size; + _index = 0; + _next_size = 0; + } + + // Pop an element from the current frontier. Returns a std::optional + auto pop() { + const auto i = _index++; + return i < _current_size ? std::optional{_current[i]} : std::nullopt; + } + + std::pair multi_pop() { + const auto pos = _index.fetch_add(chunk); + return std::make_pair(_current.data() + pos, std::min((ssize_t)chunk, (ssize_t)_current_size - (ssize_t)pos)); + } + + + // Push an element to the next frontier. + void push(const T& x) { + const auto i = _next_size++; + _next[i] = x; + } + + std::pair multi_push() { + const auto pos = _next_size.fetch_add(chunk); + return std::make_pair(_next.data() + pos, chunk); + } + + void clear() { + _current_size = 0; + _index = 0; + _next_size = 0; + } + + template + friend std::ostream& operator<<(std::ostream&, const mt_queue&); +}; + +template +std::ostream& operator<<(std::ostream& os, const mt_queue& q) { + os << "current<" << q._index << ',' << q._current_size << ',' << (void*)&q._current[0] << ">["; + for(size_t i = 0; i < q._current_size; ++i) + os << q._current[i] << ' '; + os << "] next<" << q._next_size << ',' << (void*)&q._next[0] << ">["; + for(size_t i = 0; i < q._next_size; ++i) + os << q._next[i] << ' '; + return os << ']'; +} + +#endif // MT_QUEUE_H_ diff --git a/old_champarnaud_set.cc b/old_champarnaud_set.cc index 98ffe43..9b6a7ad 100644 --- a/old_champarnaud_set.cc +++ b/old_champarnaud_set.cc @@ -112,32 +112,47 @@ mer_t champarnaud_mer(mer_t m) { return primary_division<1, mer_ops::nb_mers / mer_ops::alpha>::divide(m); } -int main(int argc, char* argv[]) { - old_champarnaud_set args(argc, argv); +template +struct amain { + int operator()(const old_champarnaud_set& args) { + std::cerr << "Problem size too big" << std::endl; + return EXIT_FAILURE; + } +}; + +template +struct amain { + int operator()(const old_champarnaud_set& args) { + std::bitset done; + std::vector res; + + for(mer_t m = 0; m < mer_ops::nb_mers; ++m) { + if(done.test(m)) continue; // Already done that PCR + + // By definition, m is the smallest. Print the PCR and whether m is a + // Lyndon word. + // std::cout << (size_t)m; + done.set(m); + for(mer_t nm = mer_ops::nmer(m); nm != m; nm = mer_ops::nmer(nm)) { + // std::cout << ',' << (size_t)nm; + done.set(nm); + } + // std::cout << ": " << is_lyndon(m) << '\n'; + const auto mds_member = champarnaud_mer(m); + // std::cout << "->" << (size_t)mds_member << "<-\n"; + res.push_back(mds_member); + } + std::sort(res.begin(), res.end()); + std::cout << joinT(res, ',') << '\n'; - // std::cout << "4 5 " << is_lyndon<4>(5) << '\n'; + return EXIT_SUCCESS; + } +}; - std::bitset done; - std::vector res; - for(mer_t m = 0; m < mer_ops::nb_mers; ++m) { - if(done.test(m)) continue; // Already done that PCR +int main(int argc, char* argv[]) { + const old_champarnaud_set args(argc, argv); - // By definition, m is the smallest. Print the PCR and whether m is a - // Lyndon word. - // std::cout << (size_t)m; - done.set(m); - for(mer_t nm = mer_ops::nmer(m); nm != m; nm = mer_ops::nmer(nm)) { - // std::cout << ',' << (size_t)nm; - done.set(nm); - } - // std::cout << ": " << is_lyndon(m) << '\n'; - const auto mds_member = champarnaud_mer(m); - // std::cout << "->" << (size_t)mds_member << "<-\n"; - res.push_back(mds_member); - } - std::sort(res.begin(), res.end()); - std::cout << joinT(res, ',') << '\n'; + return amain()(args); - return EXIT_SUCCESS; } diff --git a/opt_canon.cc b/opt_canon.cc index bbe083e..cb8f74e 100644 --- a/opt_canon.cc +++ b/opt_canon.cc @@ -1,16 +1,22 @@ #include -#include +#include #include -#include -#include +#include #include +#include +#include +#include +#include +#include +#include +#include "opt_canon.hpp" #include "misc.hpp" #include "common.hpp" -#include "tarjan_scc.hpp" +#include "mt_queue.hpp" #include "random_seed.hpp" +#include "simple_thread_pool.hpp" -#include "opt_canon.hpp" #ifndef K #error Must define k-mer length K @@ -21,65 +27,136 @@ #endif #include "mer_op.hpp" +typedef amer_type amer_t; +typedef amer_t::mer_ops mer_ops; +typedef amer_t::mer_t mer_t; -typedef mer_op_type mer_ops; -typedef mer_ops::mer_t mer_t; +// Does a BFS to detect a new cycle in the de Bruijn graph minus a set. Starts +// from m and the reverse complement of m (rcm) and check for a loop back to m +// or back to rcm. +template +struct symm_bfs { + mt_queue _queue; + // std::vector> _visited; + std::vector _visited; // Don't use atomic operations. See mark_visited(). + simple_thread_pool> _pool; + static void noprogress(amer_t i) {} -struct is_in_set { - const std::unordered_set& set; - is_in_set(const std::unordered_set& s) : set(s) {} - bool operator()(mer_t m) const { return set.find(m) != set.cend(); } -}; + symm_bfs(int ths) + : _queue(mer_ops::nb_mers) + , _visited(mer_ops::nb_mers) + , _pool(ths) + {} + ~symm_bfs() { _pool.stop(); } -struct can_is_in_set { - const std::unordered_set& set; - can_is_in_set(const std::unordered_set& s) : set(s) {} - bool operator()(mer_t m) const { return set.find(mer_ops::canonical(m)) != set.cend(); } -}; + template + bool has_cycle(Fn in_set, amer_t m) { + // std::fill(_visited.begin(), _visited.end(), 0); + std::memset(_visited.data(), 0, _visited.size() * sizeof(decltype(_visited)::value_type)); + _queue.clear(); + volatile bool found_loop = false; + // The reverse complement of m is also considered removed from set and a + // loop involving rcm also triggers returning true. + const auto rcm = m.reverse_comp(); + + // Process one when starting from m. Consider rcm not part of the set.n + auto process_level = [&](int) { + std::pair push_loc{nullptr, 0}; + ssize_t push_index = 0; + + while(!found_loop) { + const auto slice = _queue.multi_pop(); + if(slice.second <= 0) break; // Finished queue of current level + + // Slice of length slice.second or ends with sentinel value m + for(ssize_t i = 0; i < slice.second && slice.first[i] != m; ++i) { + amer_t::mer_rc_pair nmer_rc(slice.first[i].nmer(0)); + for(unsigned b = 0; b < mer_ops::alpha; ++b, ++nmer_rc) { + if(mark_visited(nmer_rc.mer)) { + const bool is_in_set = (nmer_rc.mer != rcm) && in_set(nmer_rc); + if(is_in_set) continue; // Ignore mers in_set + if(push_index >= push_loc.second) { + push_loc = _queue.multi_push(); + push_index = 0; + } + push_loc.first[push_index] = nmer_rc.mer; + ++push_index; + } else if(nmer_rc.mer == m) { + found_loop = true; // Loop involving m or rcm + break; + } + } + } + } -struct is_in_union { - const std::unordered_set& set; - is_in_union(const std::unordered_set& s) : set(s) {} - bool operator()(mer_t m) const { return set.find(m) != set.cend() || set.find(mer_ops::reverse_comp(m)) != set.cend(); } -}; + // Padd unfilled location with sentinel value m + if(push_loc.first && push_index < push_loc.second) + push_loc.first[push_index] = m; + }; + _pool.set_work(process_level); + + // Prime queue. Simpler but equivalent to process_level + mark_visited(m); + amer_t::mer_rc_pair nmer_rc(m.nmer(0)); + for(unsigned b = 0; b < mer_ops::alpha; ++b, ++nmer_rc) { + if(mark_visited(nmer_rc.mer)) { + if((nmer_rc.mer != rcm) && in_set(nmer_rc)) continue; + _queue.push(nmer_rc.mer); + } else if(nmer_rc.mer == m) { + return true; + } + } + _queue.swap(); -// True if mer m or its reverse complement is in s1, but neither are in s2. -struct is_in_opt { - const std::unordered_set& set1; - const std::unordered_set& set2; - is_in_opt(const std::unordered_set& s1, const std::unordered_set& s2) - : set1(s1) - , set2(s2) - {} - bool operator()(mer_t m) const { - const auto rcm = mer_ops::reverse_comp(m); - return (set1.find(m) != set1.cend() || set1.find(rcm) != set1.cend()) && \ - (set2.find(m) == set2.end() && set2.find(rcm) == set2.end()); + while(!_queue.current_empty() && !found_loop) { + _pool.start(); + _queue.swap(); + } + + return found_loop; + } + + // Mark node m as _visited. Returns true if not previously visited. I.e., + // this call is the one who changed it to visited. + bool mark_visited(const amer_t& m) { + // Don't use any atomic operations to save time, although this is not + // strictly correct for a BFS. Meaning a node could be visited multiple + // times. This is rare. More importantly, it may add a bit of useless + // work but it doesn't affect the correctness. Overall it is worth it. + const auto prev = _visited[m.val]; + _visited[m.val] = 1; + return prev == 0; + // return _visited[m.val].exchange(1) == 0; } }; -template -R to_canonical_set(const C& mers) { - R res; - for(const auto m : mers) { - const auto& rc = mer_ops::reverse_comp(m); - if(rc < m) continue; - res.insert(m); - res.insert(rc); - } +struct is_in_set { + const std::unordered_set& set; + is_in_set(const std::unordered_set& s) : set(s) {} + bool operator()(amer_t m) const { return set.find(m) != set.cend(); } +}; - return res; -} +template +struct is_in_union { + const S& set; + is_in_union(const S& s) : set(s) {} + bool operator()(amer_t m) const { + return set.find(m) != set.cend() || set.find(m.reverse_comp()) != set.cend(); + } + bool operator()(const amer_t::mer_rc_pair& pair) const { + return set.find(pair.mer) || set.find(pair.rc); + } +}; template size_t canonicalize_size(const C& mers) { size_t size = 0; - for(const auto m : mers) { - const auto rcm = mer_ops::reverse_comp(m); + for(const auto& m : mers) { + const auto rcm = m.reverse_comp(); // Add two for canonical k-mers (1 for itself, 1 for its rc), unless it is self rc (then add only 1) - if(m <= rcm) { + if(m < rcm || m == rcm) { ++size; if(m < rcm) ++size; @@ -89,131 +166,140 @@ size_t canonicalize_size(const C& mers) { return size; } -template -size_t union_size(const C& mers, const R* remove) { +template +size_t union_size(const C& mers, const S& set) { size_t size = 0; - for(const auto m : mers) { - if(remove && remove->find(m) != remove->end()) continue; + for(const auto& m : mers) { + if(set.find(m) == set.end()) continue; ++size; // Add one for the k-mer itself - const auto rcm = mer_ops::reverse_comp(m); - if(m != rcm && mers.find(rcm) == mers.end()) + const auto rcm = m.reverse_comp(); + if(m != rcm && set.find(rcm) == set.end()) ++size; // Add one for its rc if not in set } return size; } -// Does a BFS to detect a new cycle in the de Bruijn graph minus a set. Starts -// from m and the reverse complement of m (rcm) and check for a loop back to m -// or back to rcm. -struct symm_bfs { - std::queue queue; - std::vector visited; +// set as bitset for quick membership +template +struct quickset { + typedef amer_t value_type; + std::bitset* _data; + quickset() + : _data(new std::bitset) + {} + ~quickset() { delete _data; } - static void noprogress(mer_t i) {} + void set(const amer_t& x) { _data->set(x.val); } + void erase(const amer_t& x) { _data->reset(x.val); } - symm_bfs() : visited(mer_ops::nb_mers) {} + bool find(const amer_t& x) const { return _data->test(x.val); } + constexpr bool end() const { return false; } + constexpr bool cend() const { return false; } +}; - template - bool has_cycle(Fn in_set, mer_t m) { - std::cout << "Has cycle " << (size_t)m << std::endl; - const auto rcm = mer_ops::reverse_comp(m); - std::cout << "Clear" << std::endl; - std::fill(visited.begin(), visited.end(), false); - clear_queue(); - - size_t nb_visited = 0; - queue.push(m); - while(true) { // Repeat with rcm eventually - while(!queue.empty()) { - if(nb_visited % 1000 == 0) - std::cout << '\r' << nb_visited << std::flush; - const auto current = queue.front(); - queue.pop(); - visited[current] = true; - ++nb_visited; - - for(unsigned b = 0; b < mer_ops::alpha; ++b) { - const auto nmer = mer_ops::nmer(current, b); - if(in_set(nmer)) continue; // Ignore mers in_set - if(!visited[nmer]) { - queue.push(nmer); - } else if(nmer == m || nmer == rcm) { - return true; // Loop involving rcm - } - } - } +namespace +{ + volatile std::sig_atomic_t terminate = 0; +} + +void signal_handler(int signal) +{ + terminate = 1; +} - if(visited[rcm]) break; - queue.push(rcm); - } +template +struct amain { + int operator()(const opt_canon& args) { + std::cerr << "Problem size too big" << std::endl; + return EXIT_FAILURE; + } +}; - return false; - } +template +struct amain { + int operator()(const opt_canon& args) { + auto prg = seeded_prg(args.oseed_given ? args.oseed_arg : nullptr, + args.iseed_given ? args.iseed_arg : nullptr); - void clear_queue() { - while(!queue.empty()) queue.pop(); - } -}; + // Install a signal handler so the computation can be stopped at any time + std::signal(SIGINT, signal_handler); + std::signal(SIGTERM, signal_handler); -// Greedy optimization procedure for mer_set using the random order. If -// can_super is true, result is a super set of the canonicalized mer_set. -struct greedy_opt { - // tarjan_scc comp_scc; - symm_bfs bfs; - std::unordered_set remove; - - template - std::pair optimize(const C& mer_set, const std::vector& order, bool can_super, uint64_t max_iteration = 0) { - remove.clear(); - // std::cout << "counts_orig" << std::endl; - // const auto counts_orig = comp_scc.scc_counts(is_in_set(mer_set)); - - uint64_t iteration = 0; - for(const auto m : order) { - if(can_super && m <= mer_ops::reverse_comp(m)) continue; // Must it be a super-set of canonical - - // Try adding m to remove. If increase SCCs, don't keep it - remove.insert(m); - const is_in_opt opt(mer_set, remove); - std::cout << "remove " << (size_t)m << std::endl; - // const auto counts = comp_scc.scc_counts(opt); - // if(counts.first > counts_orig.first || counts.second > counts_orig.second) - // remove.erase(m); - if(bfs.has_cycle(opt, m)) - remove.erase(m); - - ++iteration; - std::cout << "iteration " << iteration << std::endl; - if(max_iteration > 0 && iteration >= max_iteration) break; - } - return std::make_pair(0, 0); - } -}; -int main(int argc, char* argv[]) { - opt_canon args(argc, argv); + auto orig_set = get_mds>(args.sketch_file_arg, args.sketch_arg); + //auto order = get_mds>(args.sketch_file_arg, args.sketch_arg); + std::vector order(orig_set.cbegin(), orig_set.cend()); + std::shuffle(order.begin(), order.end(), prg); - auto prg = seeded_prg(args.oseed_given ? args.oseed_arg : nullptr, - args.iseed_given ? args.iseed_arg : nullptr); + quickset mer_set; + for(const auto& m : order) + mer_set.set(m); - std::cout << "Read set" << std::endl; - const auto mer_set = get_mds>(args.sketch_file_arg, args.sketch_arg); - std::cout << "Shuffle set" << std::endl; - std::vector order(mer_set.begin(), mer_set.end()); - std::shuffle(order.begin(), order.end(), prg); + is_in_union union_set(mer_set); + size_t removed = 0; - greedy_opt optimizer; + int nb_threads = args.threads_arg > std::thread::hardware_concurrency() ? std::thread::hardware_concurrency() : args.threads_arg; + symm_bfs bfs(nb_threads); + std::cout << "original set: " << order.size() + << "\ncanonicalized set: " << canonicalize_size(order) + << "\nunion set: " << union_size(order, mer_set) << '\n'; + const auto begin = std::chrono::steady_clock::now(); - const auto counts = optimizer.optimize(mer_set, order, true, args.iteration_given? args.iteration_arg : 0); - std::cout << "set\tsize\tsccs\n" - << "orig\t" << mer_set.size() << '\t' << counts << '\n' - << "union\t" << union_size(mer_set, (std::set*)nullptr) /* << '\t' << optimizer.comp_scc.scc_counts(is_in_union(mer_set)) */ << '\n' - << "canon\t" << canonicalize_size(mer_set) /* << '\t' << optimizer.comp_scc.scc_counts(can_is_in_set(mer_set)) */ << '\n' - << "super\t" << union_size(mer_set, &optimizer.remove) /* << '\t' << optimizer.comp_scc.scc_counts(is_in_opt(mer_set, optimizer.remove)) */ << '\n'; + size_t progress = 0; + const auto progress_suffix = isatty(1) ? '\r' : '\n'; - optimizer.optimize(mer_set, order, false, args.iteration_given? args.iteration_arg : 0); - std::cout << "opt\t" << union_size(mer_set, &optimizer.remove) /* << '\t' << optimizer.comp_scc.scc_counts(is_in_opt(mer_set, optimizer.remove)) */ << '\n'; + for(const auto& m : order) { + if(terminate) break; + if(args.progress_flag) { + std::cout << progress << ' ' << removed << ' ' + << (progress / (1e-6 + std::chrono::duration_cast(std::chrono::steady_clock::now() - begin).count())) + << progress_suffix << progress_suffix << std::flush; + ++progress; + } + + const auto rcm = m.reverse_comp(); + // Must be a super set of canonicalize + if(args.can_flag && (m < rcm || m == rcm)) continue; + // Skip if rcm is also in set and not the canonical k-mer (avoid double computation) + if(mer_set.find(rcm) != mer_set.end() && rcm < m) continue; + const bool has_cycle = bfs.has_cycle(union_set, m) || bfs.has_cycle(union_set, rcm); + if(!has_cycle) { + ++removed; + mer_set.erase(m); + mer_set.erase(rcm); + } + } + if(progress) std::cout << '\n'; + std::cout << "Removed " << removed << '/' << progress << "\nunion set: " << union_size(order, mer_set) << '\n'; + + if(args.output_given) { + std::ofstream out(args.output_arg); + bool first = true; + for(mer_t i = 0; i < mer_ops::nb_mers; ++i) { + if(!mer_set._data->test(i)) continue; + if(!first) { + out << ','; + first = false; + } + out << amer_t(i); + if(!out.good()) break; + } + out.close(); + if(!out.good()) { + std::cerr << "Error while writing set to '" << args.output_arg << "''" << std::endl; + return EXIT_FAILURE; + } + } + + return EXIT_SUCCESS; + } +}; + + + +int main(int argc, char* argv[]) { + opt_canon args(argc, argv); - return EXIT_SUCCESS; +return amain()(args); } diff --git a/opt_canon.yaggo b/opt_canon.yaggo index 5095dad..797b805 100644 --- a/opt_canon.yaggo +++ b/opt_canon.yaggo @@ -1,10 +1,25 @@ -purpose "Find set one less than union" +description "test" option('f', 'sketch-file') { description 'File with sketch mer set' c_string; typestr 'path' } +option('c', 'can') { + description 'Must be a super set of the canonicalized set' + flag; off +} + +option('t', 'threads') { + description 'Number of threads' + uint32; typestr 'N'; default 1 +} + +option('o', 'output') { + description 'Output created set' + c_string; typestr 'path' +} + option('s', 'iseed') { description 'Input seed file' typestr 'path' @@ -17,9 +32,9 @@ option('S', 'oseed') { c_string } -option('i', 'iteration') { - description 'Maximum number of optimization iterations' - uint64 +option('p', 'progress') { + description 'Show progress' + flag; off } arg('sketch') { diff --git a/simple_thread_pool.hpp b/simple_thread_pool.hpp new file mode 100644 index 0000000..b141531 --- /dev/null +++ b/simple_thread_pool.hpp @@ -0,0 +1,63 @@ +#ifndef SIMPLE_THREAD_POOL_H_ +#define SIMPLE_THREAD_POOL_H_ + +#include +#include +#include +#include +#include + +template +class simple_thread_pool { +protected: + std::barrier> _wait_barrier, _done_barrier; + std::vector _ths; + volatile bool _done; + Fn _work; + + void run_thread(unsigned index) { + while(true) { + _wait_barrier.arrive_and_wait(); + if(_done) break; + try { + _work(index); + } catch(...) { + std::cerr << "Pool work ended with an exception!"; + } + _done_barrier.arrive_and_wait(); + } + } + +public: + simple_thread_pool(unsigned nb_threads) + : _wait_barrier(nb_threads + 1, [](){}) // Barriers, no completion function + , _done_barrier(nb_threads + 1, [](){}) + , _done(false) + { + for(unsigned i = 0; i < nb_threads; ++i) + _ths.push_back(std::thread(&simple_thread_pool::run_thread, std::ref(*this), i)); + } + + void set_work(Fn fn) { _work = fn; } + + // Must be called when no thread is currently working. I.e., before calling + // start(), or after start() has completed. Wait for the threads to + // terminate. + void stop() { + _done = true; + _wait_barrier.arrive_and_wait(); + for(auto& th : _ths) + th.join(); + } + + // Start working by all the threads in the pool (i.e., execute the _work + // function). Returns when the work is done (all the threads finished + // executing _work). Concurrent call to start() are not supported. + void start() { + _wait_barrier.arrive_and_wait(); // Let all the thread start working + _done_barrier.arrive_and_wait(); // Wait for the threads to be done working + } +}; + + +#endif // SIMPLE_THREAD_POOL_H_ diff --git a/sketch_components.cc b/sketch_components.cc index c42f2f1..fb929ff 100644 --- a/sketch_components.cc +++ b/sketch_components.cc @@ -6,17 +6,11 @@ #include #include #include -#include -#include -#include #include #include -#include - -#include "misc.hpp" -#include "common.hpp" -#include "sketch_components.hpp" -#include "tarjan_scc.hpp" +#include +#include +#include #ifndef K #error Must define k-mer length K @@ -26,13 +20,18 @@ #error Must define alphabet length ALPHA #endif +#include "typename.hpp" +#include "misc.hpp" +#include "common.hpp" +#include "sketch_components.hpp" +#include "tarjan_scc.hpp" #include "mer_op.hpp" typedef mer_op_type mer_ops; typedef mer_ops::mer_t mer_t; // Function checking if a mer is in the set of selecting mers. -typedef bool (*in_set_fn)(mer_t); +// typedef bool (*in_set_fn)(mer_t); struct is_in_set { const std::unordered_set& set; @@ -54,43 +53,50 @@ struct is_in_union { int main(int argc, char* argv[]) { sketch_components args(argc, argv); - const auto mer_set = get_mds>(args.sketch_file_arg, args.sketch_arg); - - mer_t components = 0, in_components = 0, visited = 0; - size_t updates = 0; - auto progress = [&]() { - if(args.progress_flag) { - if(updates % 1024 == 0) { - std::cerr << '\r' - << "comps " << std::setw(10) << components - << " in_comps " << std::setw(10) << in_components - << " visited " << std::setw(10) << visited - << ' ' << std::setw(5) << std::fixed << std::setprecision(1) << (100.0 * (double)visited / mer_ops::nb_mers) << '%' - << std::flush; + + if constexpr(mer_ops::ak_bits > mer_ops::max_bits) { + std::cerr << "Problem size too big" << std::endl; + return EXIT_FAILURE; + } else { + const auto mer_set = get_mds>(args.sketch_file_arg, args.sketch_arg); + + mer_t components = 0, in_components = 0, visited = 0; + size_t updates = 0; + auto progress = [&]() { + if(args.progress_flag) { + if(updates % 1024 == 0) { + std::cerr << '\r' + << "comps " << std::setw(10) << components + << " in_comps " << std::setw(10) << in_components + << " visited " << std::setw(10) << visited + << ' ' << std::setw(5) << std::fixed << std::setprecision(1) << (100.0 * (double)visited / mer_ops::nb_mers) << '%' + << std::flush; + } + ++updates; } - ++updates; + }; + auto new_scc = [&components,&progress](mer_t m) { ++components; progress(); }; + auto new_node = [&in_components,&progress](mer_t m) { ++in_components; progress(); }; + auto new_visit = [&visited,&progress](mer_t m) { ++visited; progress(); }; + // static_assert(std::is_integral::value, "mer_t is not integral"); + + tarjan_scc comp_scc; + if(args.canonical_flag) { + const can_is_in_set can_fn(mer_set); + comp_scc.scc_iterate(can_fn, new_scc, new_node, new_visit); + } else if(args.union_flag) { + const is_in_union union_fn(mer_set); + comp_scc.scc_iterate(union_fn, new_scc, new_node, new_visit); + } else { + const is_in_set set_fn(mer_set); + comp_scc.scc_iterate(set_fn, new_scc, new_node, new_visit); } - }; - auto new_scc = [&components,&progress](mer_t m) { ++components; progress(); }; - auto new_node = [&in_components,&progress](mer_t m) { ++in_components; progress(); }; - auto new_visit = [&visited,&progress](mer_t m) { ++visited; progress(); }; - - tarjan_scc comp_scc; - if(args.canonical_flag) { - const can_is_in_set can_fn(mer_set); - comp_scc.scc_iterate(can_fn, new_scc, new_node, new_visit); - } else if(args.union_flag) { - const is_in_union union_fn(mer_set); - comp_scc.scc_iterate(union_fn, new_scc, new_node, new_visit); - } else { - const is_in_set set_fn(mer_set); - comp_scc.scc_iterate(set_fn, new_scc, new_node, new_visit); - } - if(args.progress_flag) - std::cerr << std::endl; + if(args.progress_flag) + std::cerr << std::endl; - std::cout << (size_t)components << ',' << (size_t)in_components << '\n'; + std::cout << (size_t)components << ',' << (size_t)in_components << '\n'; - return EXIT_SUCCESS; + return EXIT_SUCCESS; + } }