Skip to content

Commit d2e03fd

Browse files
authored
Increase use of workspace. (#1498)
3% increase in nps on benchmark. 100% increase in nps in #1 positions.
1 parent e93a2a8 commit d2e03fd

File tree

3 files changed

+28
-11
lines changed

3 files changed

+28
-11
lines changed

src/chess/position.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ class PositionHistory {
9393
public:
9494
PositionHistory() = default;
9595
PositionHistory(const PositionHistory& other) = default;
96+
PositionHistory(PositionHistory&& other) = default;
97+
98+
PositionHistory& operator=(const PositionHistory& other) = default;
99+
PositionHistory& operator=(PositionHistory&& other) = default;
96100

97101
// Returns first position of the game (or fen from which it was initialized).
98102
const Position& Starting() const { return positions_.front(); }

src/mcts/search.cc

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,11 +1345,11 @@ void SearchWorker::GatherMinibatch2() {
13451345
}
13461346

13471347
void SearchWorker::ProcessPickedTask(int start_idx, int end_idx,
1348-
TaskWorkspace*) {
1348+
TaskWorkspace* workspace) {
1349+
auto& history = workspace->history;
13491350
// This code runs multiple passes of work across the same input in order to
13501351
// reduce taking/dropping mutexes in quick succession.
1351-
PositionHistory history = search_->played_history_;
1352-
history.Reserve(search_->played_history_.GetLength() + 30);
1352+
history = search_->played_history_;
13531353

13541354
// First pass - Extend nodes.
13551355
for (int i = start_idx; i < end_idx; i++) {
@@ -1491,15 +1491,15 @@ void SearchWorker::PickNodesToExtendTask(Node* node, int base_depth,
14911491
// with tasks.
14921492
// TODO: pre-reserve visits_to_perform for expected depth and likely maximum
14931493
// width. Maybe even do so outside of lock scope.
1494-
std::vector<std::unique_ptr<std::array<int, 256>>> visits_to_perform;
14951494
auto& vtp_buffer = workspace->vtp_buffer;
1496-
visits_to_perform.reserve(30);
1497-
std::vector<int> vtp_last_filled;
1498-
vtp_last_filled.reserve(30);
1499-
std::vector<int> current_path;
1500-
current_path.reserve(30);
1501-
std::vector<Move> moves_to_path = moves_to_base;
1502-
moves_to_path.reserve(30);
1495+
auto& visits_to_perform = workspace->visits_to_perform;
1496+
visits_to_perform.clear();
1497+
auto& vtp_last_filled = workspace->vtp_last_filled;
1498+
vtp_last_filled.clear();
1499+
auto& current_path = workspace->current_path;
1500+
current_path.clear();
1501+
auto& moves_to_path = workspace->moves_to_path;
1502+
moves_to_path = moves_to_base;
15031503
// Sometimes receiver is reused, othertimes not, so only jump start if small.
15041504
if (receiver->capacity() < 30) {
15051505
receiver->reserve(receiver->size() + 30);

src/mcts/search.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,19 @@ class SearchWorker {
379379
struct TaskWorkspace {
380380
std::array<Node::Iterator, 256> cur_iters;
381381
std::vector<std::unique_ptr<std::array<int, 256>>> vtp_buffer;
382+
std::vector<std::unique_ptr<std::array<int, 256>>> visits_to_perform;
383+
std::vector<int> vtp_last_filled;
384+
std::vector<int> current_path;
385+
std::vector<Move> moves_to_path;
386+
PositionHistory history;
387+
TaskWorkspace() {
388+
vtp_buffer.reserve(30);
389+
visits_to_perform.reserve(30);
390+
vtp_last_filled.reserve(30);
391+
current_path.reserve(30);
392+
moves_to_path.reserve(30);
393+
history.Reserve(30);
394+
}
382395
};
383396

384397
struct PickTask {

0 commit comments

Comments
 (0)