Skip to content

Commit 331ee24

Browse files
Merge branch 'LeelaChessZero:master' into master
2 parents 8286b55 + 1b685ff commit 331ee24

File tree

10 files changed

+169
-100
lines changed

10 files changed

+169
-100
lines changed

.circleci/config.yml

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,14 @@ jobs:
99
name: Install build tools
1010
command: |
1111
apt-get update
12-
apt-get -y install git python3-pip gcc-10 g++-10 clang-12 zlib1g zlib1g-dev
12+
apt-get -y install git python3-pip gcc-10 g++-10 clang-12 zlib1g zlib1g-dev wget
1313
pip3 install meson==0.63
1414
pip3 install ninja
15+
- run:
16+
name: Install onnxruntime
17+
command: |
18+
wget https://github.com/microsoft/onnxruntime/releases/download/v1.22.0/onnxruntime-linux-x64-1.22.0.tgz -P /tmp
19+
tar xzf /tmp/onnxruntime-linux-x64-1.22.0.tgz -C /tmp
1520
- run:
1621
name: "Pull Submodules"
1722
command: git submodule update --init
@@ -20,13 +25,13 @@ jobs:
2025
environment:
2126
CC: gcc-10
2227
CXX: g++-10
23-
command: meson build-gcc -Dgtest=false
28+
command: meson build-gcc -Dgtest=false -Donnx_include=/tmp/onnxruntime-linux-x64-1.22.0/include -Donnx_libdir=/tmp/onnxruntime-linux-x64-1.22.0/lib
2429
- run:
2530
name: Meson Clang
2631
environment:
2732
CC: clang-12
2833
CXX: clang++-12
29-
command: meson build-clang -Dgtest=false -Db_lto=false
34+
command: meson build-clang -Dgtest=false -Db_lto=false -Donnx_include=/tmp/onnxruntime-linux-x64-1.22.0/include -Donnx_libdir=/tmp/onnxruntime-linux-x64-1.22.0/lib
3035
- run:
3136
name: Build GCC
3237
command: |

meson.build

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ if get_option('build_backends')
340340

341341
eigen_dep = dependency('eigen3')
342342
# Check for needed header, bad dependency seen in the widl.
343-
if eigen_dep.found() and cc.has_header('Eigen/Core')
343+
if eigen_dep.found() and cc.has_header('Eigen/Core', dependencies: eigen_dep)
344344
deps += eigen_dep
345345
else
346346
deps += subproject('eigen').get_variable('eigen_dep')

src/engine.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ void Engine::EnsureSearchStopped() {
162162
}
163163

164164
void Engine::UpdateBackendConfig() {
165+
LOGFILE << "Update backend configuration.";
165166
const std::string backend_name =
166167
options_.Get<std::string>(SharedBackendParams::kBackendId);
167168
if (!backend_ || backend_name != backend_name_ ||
@@ -182,6 +183,7 @@ void Engine::EnsureSyzygyTablebasesLoaded() {
182183
previous_tb_paths_ = tb_paths;
183184

184185
if (tb_paths.empty()) {
186+
LOGFILE << "Reset Syzygy tablebases.";
185187
syzygy_tb_.reset();
186188
} else {
187189
syzygy_tb_ = std::make_unique<SyzygyTablebase>();
@@ -198,6 +200,7 @@ void Engine::EnsureSyzygyTablebasesLoaded() {
198200
// Initializes the search with either the specified position for the normal
199201
// search or the position one ply trimmed for the ponder search.
200202
void Engine::InitializeSearchPosition(bool for_ponder) {
203+
LOGFILE << "Setting a new search position.";
201204
assert(last_position_);
202205
if (!for_ponder) {
203206
search_->SetPosition(*last_position_);

src/neural/backends/network_onnx.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@
2828
#include <algorithm>
2929
#include <cassert>
3030
#include <fstream>
31+
#include <iomanip>
3132
#include <iterator>
3233
#include <memory>
34+
#include <sstream>
3335
#include <string>
3436
#include <vector>
3537

@@ -333,9 +335,11 @@ Ort::SessionOptions OnnxNetwork::GetOptions(int gpu, int threads,
333335
trt_options["trt_min_subgraph_size"] = "1";
334336
trt_options["trt_engine_cache_enable"] = "1";
335337
// We need the batch size as well as the hash, as it is set after loading.
338+
std::ostringstream oss;
339+
oss << std::hex << hash;
336340
trt_options["trt_engine_cache_prefix"] =
337341
"Lc0_ONNX_TRT_ORT_" + Ort::GetVersionString() + "_batch_" +
338-
std::to_string(batch_size) + "_" + std::format("{:x}", hash) + "_";
342+
std::to_string(batch_size) + "_" + oss.str() + "_";
339343
trt_options["trt_engine_cache_path"] = cache_dir;
340344
trt_options["trt_timing_cache_enable"] = "1";
341345
trt_options["trt_timing_cache_path"] = cache_dir;

src/search/classic/search.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,7 +1105,7 @@ void SearchWorker::RunTasks(int tid) {
11051105
// We got the spin lock, double check we're still in the clear.
11061106
if (nta < tc) {
11071107
id = tasks_taken_.fetch_add(1, std::memory_order_acq_rel);
1108-
task = &picking_tasks_[id];
1108+
task = picking_tasks_.data() + id;
11091109
task_taking_started_.store(0, std::memory_order_release);
11101110
break;
11111111
}
@@ -1153,7 +1153,7 @@ void SearchWorker::RunTasks(int tid) {
11531153
break;
11541154
}
11551155
}
1156-
picking_tasks_[id].complete = true;
1156+
picking_tasks_.data()[id].complete = true;
11571157
completed_tasks_.fetch_add(1, std::memory_order_acq_rel);
11581158
}
11591159
}

src/search/dag_classic/node.cc

Lines changed: 79 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include <cmath>
3333
#include <cstring>
3434
#include <iostream>
35+
#include <fstream>
3536
#include <list>
3637
#include <sstream>
3738
#include <thread>
@@ -430,14 +431,71 @@ void Node::UnsetLowNode() {
430431
low_node_.reset();
431432
}
432433

434+
#ifndef NDEBUG
435+
namespace {
436+
static Node::VisitorId::storage current_visitor_id = 0;
437+
}
438+
439+
Node::VisitorId::VisitorId() {
440+
id_ = ++current_visitor_id;
441+
if (id_ == 0)
442+
id_ = ++current_visitor_id;
443+
}
444+
445+
Node::VisitorId::~VisitorId() {
446+
assert(current_visitor_id == id_);
447+
}
448+
449+
bool LowNode::Visit(Node::VisitorId::type id) {
450+
if (visitor_id_ == id)
451+
return false;
452+
visitor_id_ = id;
453+
return true;
454+
}
455+
456+
template<typename VisitorType, typename EdgeVisitorType>
457+
static void TreeWalk(const Node* node, bool as_opponent,
458+
Node::VisitorId::type id,
459+
VisitorType visitor, EdgeVisitorType edge) {
460+
const std::shared_ptr<LowNode>& low_node = node->GetLowNode();
461+
if (!low_node || !low_node->Visit(id)) {
462+
return;
463+
}
464+
465+
visitor(low_node.get(), as_opponent);
466+
467+
for (auto& child_edge : node->Edges()) {
468+
auto child = child_edge.node();
469+
if (child == nullptr) {
470+
break;
471+
}
472+
edge(child, as_opponent, low_node.get());
473+
}
474+
475+
for (auto& child_edge : node->Edges()) {
476+
auto child = child_edge.node();
477+
if (child == nullptr) {
478+
return;
479+
}
480+
TreeWalk(child, !as_opponent, id, visitor, edge);
481+
}
482+
}
483+
433484
static std::string PtrToNodeName(const void* ptr) {
434485
std::ostringstream oss;
435486
oss << "n_" << ptr;
436487
return oss.str();
437488
}
438489

439-
std::string LowNode::DotNodeString() const {
440-
std::ostringstream oss;
490+
template<typename VisitorType, typename EdgeVisitorType>
491+
static void TreeWalk(const Node* node, bool as_opponent,
492+
VisitorType visitor, EdgeVisitorType edge) {
493+
Node::VisitorId id{};
494+
edge(node, as_opponent, nullptr);
495+
TreeWalk(node, !as_opponent, id, visitor, edge);
496+
}
497+
498+
void LowNode::DotNodeString(std::ofstream& oss) const {
441499
oss << PtrToNodeName(this) << " ["
442500
<< "shape=box";
443501
// Adjust formatting to limit node size.
@@ -464,12 +522,10 @@ std::string LowNode::DotNodeString() const {
464522
<< "\\n\\nThis=" << this << "\\nEdges=" << edges_.get()
465523
<< "\\nNumEdges=" << static_cast<int>(num_edges_)
466524
<< "\\nChild=" << child_.get() << "\\n\"";
467-
oss << "];";
468-
return oss.str();
525+
oss << "];" << std::endl;
469526
}
470527

471-
std::string Node::DotEdgeString(bool as_opponent, const LowNode* parent) const {
472-
std::ostringstream oss;
528+
void Node::DotEdgeString(std::ofstream& oss, bool as_opponent, const LowNode* parent) const {
473529
oss << (parent == nullptr ? "top" : PtrToNodeName(parent)) << " -> "
474530
<< (low_node_ ? PtrToNodeName(low_node_.get()) : PtrToNodeName(this))
475531
<< " [";
@@ -493,15 +549,10 @@ std::string Node::DotEdgeString(bool as_opponent, const LowNode* parent) const {
493549
<< std::noshowpos //
494550
<< "\\nLowNode=" << low_node_.get() << "\\nParent=" << parent
495551
<< "\\nIndex=" << index_ << "\\nSibling=" << sibling_.get() << "\\n\"";
496-
oss << "];";
497-
return oss.str();
552+
oss << "];" << std::endl;
498553
}
499554

500-
std::string Node::DotGraphString(bool as_opponent) const {
501-
std::ostringstream oss;
502-
std::unordered_set<const LowNode*> seen;
503-
std::list<std::pair<const Node*, bool>> unvisited_fifo;
504-
555+
void Node::DotGraphString(std::ofstream& oss, bool as_opponent) const {
505556
oss << "strict digraph {" << std::endl;
506557
oss << "edge ["
507558
<< "headport=n"
@@ -514,83 +565,37 @@ std::string Node::DotGraphString(bool as_opponent) const {
514565
<< "];" << std::endl;
515566
oss << "ranksep=" << 4.0f * std::log10(GetN()) << std::endl;
516567

517-
oss << DotEdgeString(!as_opponent) << std::endl;
518-
if (low_node_) {
519-
seen.insert(low_node_.get());
520-
unvisited_fifo.push_back(std::pair(this, as_opponent));
521-
}
522-
523-
while (!unvisited_fifo.empty()) {
524-
auto [parent_node, parent_as_opponent] = unvisited_fifo.front();
525-
unvisited_fifo.pop_front();
526-
527-
auto parent_low_node = parent_node->GetLowNode().get();
528-
seen.insert(parent_low_node);
529-
oss << parent_low_node->DotNodeString() << std::endl;
530-
531-
for (auto& child_edge : parent_node->Edges()) {
532-
auto child = child_edge.node();
533-
if (child == nullptr) break;
534-
535-
oss << child->DotEdgeString(parent_as_opponent) << std::endl;
536-
auto child_low_node = child->GetLowNode().get();
537-
if (child_low_node != nullptr &&
538-
(seen.find(child_low_node) == seen.end())) {
539-
seen.insert(child_low_node);
540-
unvisited_fifo.push_back(std::pair(child, !parent_as_opponent));
541-
}
542-
}
543-
}
568+
TreeWalk(this, !as_opponent,
569+
[&](const LowNode* low_node, bool) {
570+
low_node->DotNodeString(oss);
571+
},
572+
[&](const Node* node, bool as_opponent, const LowNode* parent) {
573+
node->DotEdgeString(oss, as_opponent, parent);
574+
});
544575

545576
oss << "}" << std::endl;
546-
547-
return oss.str();
548577
}
549578

550579
bool Node::ZeroNInFlight() const {
551-
std::unordered_set<const LowNode*> seen;
552-
std::list<const Node*> unvisited_fifo;
553580
size_t nonzero_node_count = 0;
554-
555-
if (GetNInFlight() > 0) {
556-
std::cerr << DebugString() << std::endl;
557-
++nonzero_node_count;
558-
}
559-
if (low_node_) {
560-
seen.insert(low_node_.get());
561-
unvisited_fifo.push_back(this);
562-
}
563-
564-
while (!unvisited_fifo.empty()) {
565-
auto parent_node = unvisited_fifo.front();
566-
unvisited_fifo.pop_front();
567-
568-
for (auto& child_edge : parent_node->Edges()) {
569-
auto child = child_edge.node();
570-
if (child == nullptr) break;
571-
572-
if (child->GetNInFlight() > 0) {
573-
std::cerr << child->DebugString() << std::endl;
581+
TreeWalk(this, false,
582+
[](const LowNode*, bool) {},
583+
[&](const Node* node, bool, const LowNode*) {
584+
if (node->GetNInFlight() > 0) [[unlikely]] {
585+
CERR << node->DebugString() << std::endl;
574586
++nonzero_node_count;
575587
}
576-
577-
auto child_low_node = child->GetLowNode().get();
578-
if (child_low_node != nullptr &&
579-
(seen.find(child_low_node) == seen.end())) {
580-
seen.insert(child_low_node);
581-
unvisited_fifo.push_back(child);
582-
}
583-
}
584-
}
588+
});
585589

586590
if (nonzero_node_count > 0) {
587-
std::cerr << "GetNInFlight() is nonzero on " << nonzero_node_count
591+
CERR << "GetNInFlight() is nonzero on " << nonzero_node_count
588592
<< " nodes" << std::endl;
589593
return false;
590594
}
591595

592596
return true;
593597
}
598+
#endif
594599

595600
void Node::SortEdges() const {
596601
assert(low_node_);

0 commit comments

Comments
 (0)