Suppress vl visits rebased #864

Draft · wants to merge 5 commits into master
9 changes: 9 additions & 0 deletions cpp/program/setup.cpp
@@ -676,6 +676,15 @@ vector<SearchParams> Setup::loadParams(
if(cfg.contains("numVirtualLossesPerThread"+idxStr)) params.numVirtualLossesPerThread = cfg.getDouble("numVirtualLossesPerThread"+idxStr, 0.01, 1000.0);
else if(cfg.contains("numVirtualLossesPerThread")) params.numVirtualLossesPerThread = cfg.getDouble("numVirtualLossesPerThread", 0.01, 1000.0);
else params.numVirtualLossesPerThread = 1.0;
if(cfg.contains("suppressVirtualLossExploreFactor"+idxStr)) params.suppressVirtualLossExploreFactor = cfg.getDouble("suppressVirtualLossExploreFactor"+idxStr, 1.0, 1e10);
else if(cfg.contains("suppressVirtualLossExploreFactor")) params.suppressVirtualLossExploreFactor = cfg.getDouble("suppressVirtualLossExploreFactor", 1.0, 1e10);
else params.suppressVirtualLossExploreFactor = 1e10;
if(cfg.contains("suppressVirtualLossHindsight"+idxStr)) params.suppressVirtualLossHindsight = cfg.getBool("suppressVirtualLossHindsight"+idxStr);
else if(cfg.contains("suppressVirtualLossHindsight")) params.suppressVirtualLossHindsight = cfg.getBool("suppressVirtualLossHindsight");
else params.suppressVirtualLossHindsight = false;
if(cfg.contains("suppressVirtualLossLeakCatchUp"+idxStr)) params.suppressVirtualLossLeakCatchUp = cfg.getBool("suppressVirtualLossLeakCatchUp"+idxStr);
else if(cfg.contains("suppressVirtualLossLeakCatchUp")) params.suppressVirtualLossLeakCatchUp = cfg.getBool("suppressVirtualLossLeakCatchUp");
else params.suppressVirtualLossLeakCatchUp = false;

if(cfg.contains("treeReuseCarryOverTimeFactor"+idxStr)) params.treeReuseCarryOverTimeFactor = cfg.getDouble("treeReuseCarryOverTimeFactor"+idxStr,0.0,1.0);
else if(cfg.contains("treeReuseCarryOverTimeFactor")) params.treeReuseCarryOverTimeFactor = cfg.getDouble("treeReuseCarryOverTimeFactor",0.0,1.0);
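The three new options above follow the existing pattern in loadParams: a per-bot key with the idxStr suffix takes precedence, then the plain key, and otherwise the defaults (1e10, false, false) apply. Below is a minimal sketch of how they might be set in a KataGo .cfg file, assuming the usual key = value config format; the keys come from the diff above, but the values are purely illustrative and are not recommendations from this PR.

# Hypothetical example config snippet (illustrative values only).
# Omitting these keys falls back to the defaults in the code above,
# which presumably leave the new behavior disabled.
suppressVirtualLossExploreFactor = 2.0
suppressVirtualLossHindsight = true
suppressVirtualLossLeakCatchUp = true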
116 changes: 79 additions & 37 deletions cpp/search/search.cpp
@@ -554,8 +554,8 @@ void Search::runWholeSearch(
upperBoundVisitsLeft = std::min(upperBoundVisitsLeft, (double)maxPlayouts - numPlayouts);
upperBoundVisitsLeft = std::min(upperBoundVisitsLeft, (double)maxVisits - numPlayouts - numNonPlayoutVisits);

bool finishedPlayout = runSinglePlayout(*stbuf, upperBoundVisitsLeft);
if(finishedPlayout) {
PlayoutResult playoutResult = runSinglePlayout(*stbuf, upperBoundVisitsLeft);
if(playoutResult == PLAYOUT_SUCCESS) {
numPlayouts = numPlayoutsShared.fetch_add((int64_t)1, std::memory_order_relaxed);
numPlayouts += 1;
}
@@ -747,7 +747,7 @@ void Search::beginSearch(bool pondering) {
node.statsLock.clear(std::memory_order_release);

//Update all other stats
recomputeNodeStats(node, dummyThread, 0, true);
recomputeNodeStats(node, dummyThread, true);
}
}

@@ -974,8 +974,12 @@ void Search::recursivelyRecomputeStats(SearchNode& n) {
//and has 0 visits because we began a search and then stopped it before any playouts happened.
//In that case, there's not much to recompute.
if(weightSum <= 0.0) {
assert(numVisits == 0);
assert(isRoot);
//It's also possible that a suppressed virtual loss edge visit on a multi-move chain
//causes the parent to have 0 visits... somehow???
if(searchParams.suppressVirtualLossExploreFactor >= 1e10) {
assert(numVisits == 0);
assert(isRoot);
}
}
else {
double resultUtility = getResultUtility(winLossValueAvg, noResultValueAvg);
@@ -992,7 +996,7 @@
}
else {
//Otherwise recompute it using the usual method
recomputeNodeStats(*node, thread, 0, isRoot);
recomputeNodeStats(*node, thread, isRoot);
}
};

@@ -1068,12 +1072,11 @@ void Search::computeRootValues() {
}
}


bool Search::runSinglePlayout(SearchThread& thread, double upperBoundVisitsLeft) {
PlayoutResult Search::runSinglePlayout(SearchThread& thread, double upperBoundVisitsLeft) {
//Store this value, used for futile-visit pruning this thread's root children selections.
thread.upperBoundVisitsLeft = upperBoundVisitsLeft;

bool finishedPlayout = playoutDescend(thread,*rootNode,true);
PlayoutResult playoutResult = playoutDescend(thread,*rootNode,true);

//Restore thread state back to the root state
thread.pla = rootPla;
@@ -1082,10 +1085,10 @@ bool Search::runSinglePlayout(SearchThread& thread, double upperBoundVisitsLeft)
thread.graphHash = rootGraphHash;
thread.graphPath.clear();

return finishedPlayout;
return playoutResult;
}

bool Search::playoutDescend(
PlayoutResult Search::playoutDescend(
SearchThread& thread, SearchNode& node,
bool isRoot
) {
@@ -1107,7 +1110,7 @@ bool Search::playoutDescend(
double lead = 0.0;
double weight = (searchParams.useUncertainty && nnEvaluator->supportsShorttermError()) ? searchParams.uncertaintyMaxWeight : 1.0;
addLeafValue(node, winLossValue, noResultValue, scoreMean, scoreMeanSq, lead, weight, true, false);
return true;
return PLAYOUT_SUCCESS;
}
else {
double winLossValue = 2.0 * ScoreValue::whiteWinsOfWinner(thread.history.winner, searchParams.drawEquivalentWinsForWhite) - 1;
@@ -1117,7 +1120,7 @@ bool Search::playoutDescend(
double lead = scoreMean;
double weight = (searchParams.useUncertainty && nnEvaluator->supportsShorttermError()) ? searchParams.uncertaintyMaxWeight : 1.0;
addLeafValue(node, winLossValue, noResultValue, scoreMean, scoreMeanSq, lead, weight, true, false);
return true;
return PLAYOUT_SUCCESS;
}
}

@@ -1129,25 +1132,25 @@ bool Search::playoutDescend(
//Leave the node as unevaluated - only the thread that first actually set the nnOutput into the node
//gets to update the state, to avoid races where we update the state while the node stats aren't updated yet.
if(!suc)
return false;
return PLAYOUT_FAILED;
}

bool suc = node.state.compare_exchange_strong(nodeState, SearchNode::STATE_EVALUATING, std::memory_order_seq_cst);
if(!suc) {
//Presumably someone else got there first.
//Just give up on this playout and try again from the start.
return false;
return PLAYOUT_FAILED;
}
else {
//Perform the nn evaluation and finish!
node.initializeChildren();
node.state.store(SearchNode::STATE_EXPANDED0, std::memory_order_seq_cst);
return true;
return PLAYOUT_SUCCESS;
}
}
else if(nodeState == SearchNode::STATE_EVALUATING) {
//Just give up on this playout and try again from the start.
return false;
return PLAYOUT_FAILED;
}

assert(nodeState >= SearchNode::STATE_EXPANDED0);
@@ -1157,10 +1160,12 @@ bool Search::playoutDescend(
int numChildrenFound;
int bestChildIdx;
Loc bestChildMoveLoc;
bool suppressEdgeVisit;
double suppressEdgeVisitUtilityThreshold;

SearchNode* child = NULL;
while(true) {
selectBestChildToDescend(thread,node,nodeState,numChildrenFound,bestChildIdx,bestChildMoveLoc,isRoot);
selectBestChildToDescend(thread,node,nodeState,numChildrenFound,bestChildIdx,bestChildMoveLoc,suppressEdgeVisit,suppressEdgeVisitUtilityThreshold,isRoot);

//The absurdly rare case that the move chosen is not legal
//(this should only happen either on a bug or where the nnHash doesn't have full legality information or when there's an actual hash collision).
@@ -1189,7 +1194,7 @@ bool Search::playoutDescend(

//As isReInit is true, we don't return, just keep going, since we didn't count this as a true visit in the node stats
nodeState = node.state.load(std::memory_order_acquire);
selectBestChildToDescend(thread,node,nodeState,numChildrenFound,bestChildIdx,bestChildMoveLoc,isRoot);
selectBestChildToDescend(thread,node,nodeState,numChildrenFound,bestChildIdx,bestChildMoveLoc,suppressEdgeVisit,suppressEdgeVisitUtilityThreshold,isRoot);

if(bestChildIdx >= 0) {
//New child
Expand All @@ -1198,7 +1203,7 @@ bool Search::playoutDescend(
//against someone reInitializing the output to add dirichlet noise or something, who was doing so based on an older cached
//nnOutput that still had the illegal move. If so, then just fail this playout and try again.
if(!thread.history.isLegal(thread.board,bestChildMoveLoc,thread.pla))
return false;
return PLAYOUT_FAILED;
}
//Existing child
else {
Expand All @@ -1209,7 +1214,7 @@ bool Search::playoutDescend(
assert(childrenCapacity > bestChildIdx);
(void)childrenCapacity;
children[bestChildIdx].addEdgeVisits(1);
return true;
return PLAYOUT_SUCCESS;
}
}
}
@@ -1218,7 +1223,7 @@ bool Search::playoutDescend(
//This might happen if all moves have been forbidden. The node will just get stuck counting visits without expanding
//and we won't do any search.
addCurrentNNOutputAsLeafValue(node,false);
return true;
return PLAYOUT_SUCCESS;
}

//Do we think we are searching a new child for the first time?
@@ -1275,16 +1280,21 @@ bool Search::playoutDescend(
//Even if the node was newly allocated, no need to delete the node, it will get cleaned up next time we mark and sweep the node table later.
//Clean up virtual losses in case the node is a transposition and is being used.
child->virtualLosses.fetch_add(-1,std::memory_order_release);
return false;
return PLAYOUT_FAILED;
}
}

//If edge visits is too much smaller than the child's visits, we can avoid descending.
//Instead just add edge visits and treat that as a visit.
if(maybeCatchUpEdgeVisits(thread, node, child, nodeState, bestChildIdx)) {

if(maybeCatchUpEdgeVisits(thread, node, child, nodeState, bestChildIdx, suppressEdgeVisit)) {
if(suppressEdgeVisit) {
child->virtualLosses.fetch_add(-1,std::memory_order_release);
return PLAYOUT_FAILED;
}
updateStatsAfterPlayout(node,thread,isRoot);
child->virtualLosses.fetch_add(-1,std::memory_order_release);
return true;
return PLAYOUT_SUCCESS;
}
}
//Searching an existing child
@@ -1297,10 +1307,14 @@ bool Search::playoutDescend(

//If edge visits is too much smaller than the child's visits, we can avoid descending.
//Instead just add edge visits and treat that as a visit.
if(maybeCatchUpEdgeVisits(thread, node, child, nodeState, bestChildIdx)) {
if(maybeCatchUpEdgeVisits(thread, node, child, nodeState, bestChildIdx, suppressEdgeVisit)) {
if(suppressEdgeVisit) {
child->virtualLosses.fetch_add(-1,std::memory_order_release);
return PLAYOUT_FAILED;
}
updateStatsAfterPlayout(node,thread,isRoot);
child->virtualLosses.fetch_add(-1,std::memory_order_release);
return true;
return PLAYOUT_SUCCESS;
}

//Make the move!
@@ -1323,36 +1337,59 @@ bool Search::playoutDescend(
//No insertion, child was already there
if(!result.second) {
SearchNodeChildrenReference children = node.getChildren(nodeState);
if(suppressEdgeVisit) {
child->virtualLosses.fetch_add(-1,std::memory_order_release);
return PLAYOUT_FAILED;
}
children[bestChildIdx].addEdgeVisits(1);
updateStatsAfterPlayout(node,thread,isRoot);
child->virtualLosses.fetch_add(-1,std::memory_order_release);
return true;
return PLAYOUT_SUCCESS;
}
}

//Recurse!
bool finishedPlayout = playoutDescend(thread,*child,false);
PlayoutResult childPlayoutResult = playoutDescend(thread,*child,false);
PlayoutResult ourPlayoutResult = PLAYOUT_FAILED;

//Update this node stats
if(finishedPlayout) {
nodeState = node.state.load(std::memory_order_acquire);
SearchNodeChildrenReference children = node.getChildren(nodeState);
children[bestChildIdx].addEdgeVisits(1);
if(childPlayoutResult == PLAYOUT_NOINCREMENT || childPlayoutResult == PLAYOUT_SUCCESS) {
if(searchParams.suppressVirtualLossHindsight) {
double childUtilityAvg = node.stats.utilityAvg.load(std::memory_order_acquire);
if(node.nextPla == P_WHITE)
suppressEdgeVisit = childUtilityAvg < suppressEdgeVisitUtilityThreshold;
else
suppressEdgeVisit = childUtilityAvg > suppressEdgeVisitUtilityThreshold;
}

if(!suppressEdgeVisit) {
nodeState = node.state.load(std::memory_order_acquire);
SearchNodeChildrenReference children = node.getChildren(nodeState);
children[bestChildIdx].addEdgeVisits(1);
ourPlayoutResult = PLAYOUT_SUCCESS;
}
else {
ourPlayoutResult = PLAYOUT_NOINCREMENT;
}
updateStatsAfterPlayout(node,thread,isRoot);
}
child->virtualLosses.fetch_add(-1,std::memory_order_release);

return finishedPlayout;
return ourPlayoutResult;
}


//If edge visits is too much smaller than the child's visits, we can avoid descending.
//Instead just add edge visits and return immediately.
//Returns true if we do perform a catch up edge visit, OR if the child visits is already sufficient but suppressEdgeVisit
//is true. In other words, returns true when we can terminate the playout and false when we need to go deeper.
bool Search::maybeCatchUpEdgeVisits(
SearchThread& thread,
SearchNode& node,
SearchNode* child,
const SearchNodeState& nodeState,
const int bestChildIdx
const int bestChildIdx,
bool suppressEdgeVisit
) {
//Don't need to do this since we already are pretty recent as of finding the best child.
//nodeState = node.state.load(std::memory_order_acquire);
@@ -1373,15 +1410,20 @@ bool Search::maybeCatchUpEdgeVisits(
if(searchParams.graphSearchCatchUpLeakProb > 0.0 && edgeVisits < childVisits && thread.rand.nextBool(searchParams.graphSearchCatchUpLeakProb))
return false;

if(edgeVisits >= childVisits)
return false;
if(suppressEdgeVisit)
return !searchParams.suppressVirtualLossLeakCatchUp;

//If the edge visits exceeds the child then we need to search the child more, but as long as that's not the case,
//we can add more edge visits.
constexpr int64_t numToAdd = 1;
// int64_t numToAdd;
do {
while(!childPointer.compexweakEdgeVisits(edgeVisits, edgeVisits + numToAdd)) {
if(edgeVisits >= childVisits)
return false;
// numToAdd = std::min((childVisits - edgeVisits + 3) / 4, maxNumToAdd);
} while(!childPointer.compexweakEdgeVisits(edgeVisits, edgeVisits + numToAdd));
}

return true;
}
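Note that the diff refers to a PlayoutResult type with the values PLAYOUT_FAILED, PLAYOUT_NOINCREMENT, and PLAYOUT_SUCCESS, whose declaration (presumably added to a search header not included in this view) does not appear above. The following is a minimal sketch of what such a declaration might look like, inferred only from how the values are used in this diff and not taken from the actual PR.

// Hypothetical sketch; the real declaration lives in a header outside this diff.
// PLAYOUT_FAILED:      the playout was abandoned (e.g. lost a race) and counted nothing.
// PLAYOUT_NOINCREMENT: node stats were updated but the edge visit was suppressed,
//                      so it does not count as a finished playout at the root.
// PLAYOUT_SUCCESS:     the playout finished normally and the edge visit was added.
enum PlayoutResult {
  PLAYOUT_FAILED,
  PLAYOUT_NOINCREMENT,
  PLAYOUT_SUCCESS,
};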