Skip to content

Commit 0e1c34d

Browse files
authored
Early exit if backend is idle and there is work to do. (#1503)
* Early exit gathering if all collisions and backend is idle.
* Fix logic.
* Even more aggressive.
* Don't enable if threads=1.
* Parameterise the behavior.
1 parent d2e03fd commit 0e1c34d

File tree

4 files changed

+42
-1
lines changed

4 files changed

+42
-1
lines changed

src/mcts/params.cc

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,16 @@ const OptionId SearchParams::kMinimumWorkPerTaskForProcessingId{
297297
"minimum-per-task-processing", "MinimumPerTaskProcessing",
298298
"Processing work won't be split into chunks smaller than this (unless its "
299299
"more than half of MinimumProcessingWork)."};
300+
const OptionId SearchParams::kIdlingMinimumWorkId{
301+
"idling-minimum-work", "IdlingMinimumWork",
302+
"Only early exit gathering due to 'idle' backend if more than this many "
303+
"nodes will be sent to the backend."};
304+
const OptionId SearchParams::kThreadIdlingThresholdId{
305+
"thread-idling-threshold", "ThreadIdlingThreshold",
306+
"If there are more than this number of search threads that are not "
307+
"actively in the process of either sending data to the backend or waiting "
308+
"for data from the backend, assume that the backend is idle."};
309+
300310
void SearchParams::Populate(OptionsParser* options) {
301311
// Here the uci optimized defaults are set.
302312
// Many of them are overridden with training specific values in tournament.cc.
@@ -369,6 +379,8 @@ void SearchParams::Populate(OptionsParser* options) {
369379
options->Add<IntOption>(kMinimumRemainingWorkSizeForPickingId, 0, 100000) =
370380
20;
371381
options->Add<IntOption>(kMinimumWorkPerTaskForProcessingId, 1, 100000) = 8;
382+
options->Add<IntOption>(kIdlingMinimumWorkId, 0, 10000) = 0;
383+
options->Add<IntOption>(kThreadIdlingThresholdId, 0, 128) = 1;
372384

373385
options->HideOption(kNoiseEpsilonId);
374386
options->HideOption(kNoiseAlphaId);
@@ -449,7 +461,9 @@ SearchParams::SearchParams(const OptionsDict& options)
449461
kMinimumRemainingWorkSizeForPicking(
450462
options.Get<int>(kMinimumRemainingWorkSizeForPickingId)),
451463
kMinimumWorkPerTaskForProcessing(
452-
options.Get<int>(kMinimumWorkPerTaskForProcessingId)) {
464+
options.Get<int>(kMinimumWorkPerTaskForProcessingId)),
465+
kIdlingMinimumWork(options.Get<int>(kIdlingMinimumWorkId)),
466+
kThreadIdlingThreshold(options.Get<int>(kThreadIdlingThresholdId)) {
453467
if (std::max(std::abs(kDrawScoreSidetomove), std::abs(kDrawScoreOpponent)) +
454468
std::max(std::abs(kDrawScoreWhite), std::abs(kDrawScoreBlack)) >
455469
1.0f) {

src/mcts/params.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ class SearchParams {
128128
int GetMinimumWorkPerTaskForProcessing() const {
129129
return kMinimumWorkPerTaskForProcessing;
130130
}
131+
int GetIdlingMinimumWork() const { return kIdlingMinimumWork; }
132+
int GetThreadIdlingThreshold() const { return kThreadIdlingThreshold; }
131133

132134
// Search parameter IDs.
133135
static const OptionId kMiniBatchSizeId;
@@ -187,6 +189,8 @@ class SearchParams {
187189
static const OptionId kMinimumWorkSizeForPickingId;
188190
static const OptionId kMinimumRemainingWorkSizeForPickingId;
189191
static const OptionId kMinimumWorkPerTaskForProcessingId;
192+
static const OptionId kIdlingMinimumWorkId;
193+
static const OptionId kThreadIdlingThresholdId;
190194

191195
private:
192196
const OptionsDict& options_;
@@ -239,6 +243,8 @@ class SearchParams {
239243
const int kMinimumWorkSizeForPicking;
240244
const int kMinimumRemainingWorkSizeForPicking;
241245
const int kMinimumWorkPerTaskForProcessing;
246+
const int kIdlingMinimumWork;
247+
const int kThreadIdlingThreshold;
242248
};
243249

244250
} // namespace lczero

src/mcts/search.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,7 @@ EdgeAndNode Search::GetBestRootChildWithTemperature(float temperature) const {
795795
}
796796

797797
void Search::StartThreads(size_t how_many) {
798+
thread_count_.store(how_many, std::memory_order_release);
798799
Mutex::Lock lock(threads_mutex_);
799800
// First thread is a watchdog thread.
800801
if (threads_.size() == 0) {
@@ -1073,6 +1074,7 @@ void SearchWorker::ExecuteOneIteration() {
10731074
} else {
10741075
GatherMinibatch();
10751076
}
1077+
search_->backend_waiting_counter_.fetch_add(1, std::memory_order_relaxed);
10761078

10771079
// 2b. Collect collisions.
10781080
CollectCollisions();
@@ -1086,6 +1088,7 @@ void SearchWorker::ExecuteOneIteration() {
10861088

10871089
// 4. Run NN computation.
10881090
RunNNComputation();
1091+
search_->backend_waiting_counter_.fetch_add(-1, std::memory_order_relaxed);
10891092

10901093
// 5. Retrieve NN computations (and terminal values) into nodes.
10911094
FetchMinibatchResults();
@@ -1210,6 +1213,8 @@ void SearchWorker::GatherMinibatch2() {
12101213
// Number of nodes processed out of order.
12111214
number_out_of_order_ = 0;
12121215

1216+
int thread_count = search_->thread_count_.load(std::memory_order_acquire);
1217+
12131218
// Gather nodes to process in the current batch.
12141219
// If we had too many nodes out of order, also interrupt the iteration so
12151220
// that search can exit.
@@ -1218,6 +1223,20 @@ void SearchWorker::GatherMinibatch2() {
12181223
// If there's something to process without touching slow neural net, do it.
12191224
if (minibatch_size > 0 && computation_->GetCacheMisses() == 0) return;
12201225

1226+
// If there is backend work to be done, and the backend is idle - exit
1227+
// immediately.
1228+
// Only do this fancy work if there are multiple threads as otherwise we
1229+
// early exit from every batch since there is never another search thread to
1230+
// be keeping the backend busy. Which would mean that threads=1 has a
1231+
// massive nps drop.
1232+
if (thread_count > 1 && minibatch_size > 0 &&
1233+
computation_->GetCacheMisses() > params_.GetIdlingMinimumWork() &&
1234+
thread_count - search_->backend_waiting_counter_.load(
1235+
std::memory_order_relaxed) >
1236+
params_.GetThreadIdlingThreshold()) {
1237+
return;
1238+
}
1239+
12211240
int new_start = static_cast<int>(minibatch_.size());
12221241

12231242
PickNodesToExtend(

src/mcts/search.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ class Search {
193193
GUARDED_BY(counters_mutex_);
194194

195195
std::atomic<int> pending_searchers_{0};
196+
std::atomic<int> backend_waiting_counter_{0};
197+
std::atomic<int> thread_count_{0};
196198

197199
std::vector<std::pair<Node*, int>> shared_collisions_
198200
GUARDED_BY(nodes_mutex_);

0 commit comments

Comments
 (0)