Skip to content

Commit 4230dcb

Browse files
authored
Re-introduce double buffer in UpdatePosition, to fix perf regression in gpu_hist (#6757)
* Revert "gpu_hist performance tweaks (#5707)". This reverts commit f779980.
* Address reviewer's comment
* Fix build error
1 parent e2d8a99 commit 4230dcb

File tree

3 files changed

+63
-21
lines changed

3 files changed

+63
-21
lines changed

src/common/device_helpers.cuh

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,36 @@ class TemporaryArray {
549549
size_t size_;
550550
};
551551

552+
/**
 * \brief A double buffer, useful for algorithms like sort.
 *
 * Wraps two equally sized device buffers behind a cub::DoubleBuffer so that
 * ping-pong algorithms (e.g. cub::DeviceRadixSort) can swap between them
 * without reallocating.
 */
template <typename T>
class DoubleBuffer {
 public:
  cub::DoubleBuffer<T> buff;
  xgboost::common::Span<T> a, b;

  DoubleBuffer() = default;

  /**
   * \brief Construct from two device vectors of equal length.
   *
   * \param v1 Backing storage for the initially-current buffer.
   * \param v2 Backing storage for the alternate buffer.
   */
  template <typename VectorT>
  DoubleBuffer(VectorT *v1, VectorT *v2) {
    T *first = v1->data().get();
    T *second = v2->data().get();
    a = xgboost::common::Span<T>{first, v1->size()};
    b = xgboost::common::Span<T>{second, v2->size()};
    buff = cub::DoubleBuffer<T>{a.data(), b.data()};
  }

  /** \brief Number of elements per buffer; both buffers must agree. */
  size_t Size() const {
    CHECK_EQ(a.size(), b.size());
    return a.size();
  }

  /** \brief Access the underlying cub double buffer (for cub APIs). */
  cub::DoubleBuffer<T> &CubBuffer() { return buff; }

  /** \brief Raw pointer to the currently active buffer. */
  T *Current() { return buff.Current(); }

  /** \brief Span view over the currently active buffer. */
  xgboost::common::Span<T> CurrentSpan() {
    return xgboost::common::Span<T>{Current(), Size()};
  }

  /** \brief Raw pointer to the inactive (alternate) buffer. */
  T *Other() { return buff.Alternate(); }
};
581+
552582
/**
553583
* \brief Copies device span to std::vector.
554584
*

src/tree/gpu_hist/row_partitioner.cu

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,17 @@ void Reset(int device_idx, common::Span<RowPartitioner::RowIndexT> ridx,
103103
}
104104

105105
RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
106-
: device_idx_(device_idx), ridx_a_(num_rows), position_a_(num_rows) {
106+
: device_idx_(device_idx), ridx_a_(num_rows), position_a_(num_rows),
107+
ridx_b_(num_rows), position_b_(num_rows) {
107108
dh::safe_cuda(cudaSetDevice(device_idx_));
108-
Reset(device_idx, dh::ToSpan(ridx_a_), dh::ToSpan(position_a_));
109+
ridx_ = dh::DoubleBuffer<RowIndexT>{&ridx_a_, &ridx_b_};
110+
position_ = dh::DoubleBuffer<bst_node_t>{&position_a_, &position_b_};
111+
ridx_segments_.emplace_back(Segment(0, num_rows));
112+
113+
Reset(device_idx, ridx_.CurrentSpan(), position_.CurrentSpan());
109114
left_counts_.resize(256);
110115
thrust::fill(left_counts_.begin(), left_counts_.end(), 0);
111116
streams_.resize(2);
112-
ridx_segments_.emplace_back(Segment(0, num_rows));
113117
for (auto& stream : streams_) {
114118
dh::safe_cuda(cudaStreamCreate(&stream));
115119
}
@@ -129,15 +133,15 @@ common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(
129133
if (segment.Size() == 0) {
130134
return common::Span<const RowPartitioner::RowIndexT>();
131135
}
132-
return dh::ToSpan(ridx_a_).subspan(segment.begin, segment.Size());
136+
return ridx_.CurrentSpan().subspan(segment.begin, segment.Size());
133137
}
134138

135139
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows() {
136-
return dh::ToSpan(ridx_a_);
140+
return ridx_.CurrentSpan();
137141
}
138142

139143
common::Span<const bst_node_t> RowPartitioner::GetPosition() {
140-
return dh::ToSpan(position_a_);
144+
return position_.CurrentSpan();
141145
}
142146
std::vector<RowPartitioner::RowIndexT> RowPartitioner::GetRowsHost(
143147
bst_node_t nidx) {
@@ -159,25 +163,23 @@ void RowPartitioner::SortPositionAndCopy(const Segment& segment,
159163
bst_node_t right_nidx,
160164
int64_t* d_left_count,
161165
cudaStream_t stream) {
162-
dh::TemporaryArray<bst_node_t> position_temp(position_a_.size());
163-
dh::TemporaryArray<RowIndexT> ridx_temp(ridx_a_.size());
164166
SortPosition(
165167
// position_in
166-
common::Span<bst_node_t>(position_a_.data().get() + segment.begin,
168+
common::Span<bst_node_t>(position_.Current() + segment.begin,
167169
segment.Size()),
168170
// position_out
169-
common::Span<bst_node_t>(position_temp.data().get() + segment.begin,
171+
common::Span<bst_node_t>(position_.Other() + segment.begin,
170172
segment.Size()),
171173
// row index in
172-
common::Span<RowIndexT>(ridx_a_.data().get() + segment.begin, segment.Size()),
174+
common::Span<RowIndexT>(ridx_.Current() + segment.begin, segment.Size()),
173175
// row index out
174-
common::Span<RowIndexT>(ridx_temp.data().get() + segment.begin, segment.Size()),
176+
common::Span<RowIndexT>(ridx_.Other() + segment.begin, segment.Size()),
175177
left_nidx, right_nidx, d_left_count, stream);
176178
// Copy back key/value
177-
const auto d_position_current = position_a_.data().get() + segment.begin;
178-
const auto d_position_other = position_temp.data().get() + segment.begin;
179-
const auto d_ridx_current = ridx_a_.data().get() + segment.begin;
180-
const auto d_ridx_other = ridx_temp.data().get() + segment.begin;
179+
const auto d_position_current = position_.Current() + segment.begin;
180+
const auto d_position_other = position_.Other() + segment.begin;
181+
const auto d_ridx_current = ridx_.Current() + segment.begin;
182+
const auto d_ridx_other = ridx_.Other() + segment.begin;
181183
dh::LaunchN(device_idx_, segment.Size(), stream, [=] __device__(size_t idx) {
182184
d_position_current[idx] = d_position_other[idx];
183185
d_ridx_current[idx] = d_ridx_other[idx];

src/tree/gpu_hist/row_partitioner.cuh

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,17 @@ class RowPartitioner {
4747
/*! \brief Range of row index for each node, pointers into ridx below. */
4848
std::vector<Segment> ridx_segments_;
4949
dh::TemporaryArray<RowIndexT> ridx_a_;
50+
dh::TemporaryArray<RowIndexT> ridx_b_;
5051
dh::TemporaryArray<bst_node_t> position_a_;
52+
dh::TemporaryArray<bst_node_t> position_b_;
53+
/*! \brief mapping for node id -> rows.
54+
* This looks like:
55+
* node id | 1 | 2 |
56+
* rows idx | 3, 5, 1 | 13, 31 |
57+
*/
58+
dh::DoubleBuffer<RowIndexT> ridx_;
59+
/*! \brief mapping for row -> node id. */
60+
dh::DoubleBuffer<bst_node_t> position_;
5161
dh::caching_device_vector<int64_t>
5262
left_counts_; // Useful to keep a bunch of zeroed memory for sort position
5363
std::vector<cudaStream_t> streams_;
@@ -100,8 +110,8 @@ class RowPartitioner {
100110
void UpdatePosition(bst_node_t nidx, bst_node_t left_nidx,
101111
bst_node_t right_nidx, UpdatePositionOpT op) {
102112
Segment segment = ridx_segments_.at(nidx); // rows belongs to node nidx
103-
auto d_ridx = dh::ToSpan(ridx_a_);
104-
auto d_position = dh::ToSpan(position_a_);
113+
auto d_ridx = ridx_.CurrentSpan();
114+
auto d_position = position_.CurrentSpan();
105115
if (left_counts_.size() <= nidx) {
106116
left_counts_.resize((nidx * 2) + 1);
107117
thrust::fill(left_counts_.begin(), left_counts_.end(), 0);
@@ -148,9 +158,9 @@ class RowPartitioner {
148158
*/
149159
template <typename FinalisePositionOpT>
150160
void FinalisePosition(FinalisePositionOpT op) {
151-
auto d_position = position_a_.data().get();
152-
const auto d_ridx = ridx_a_.data().get();
153-
dh::LaunchN(device_idx_, position_a_.size(), [=] __device__(size_t idx) {
161+
auto d_position = position_.Current();
162+
const auto d_ridx = ridx_.Current();
163+
dh::LaunchN(device_idx_, position_.Size(), [=] __device__(size_t idx) {
154164
auto position = d_position[idx];
155165
RowIndexT ridx = d_ridx[idx];
156166
bst_node_t new_position = op(ridx, position);

0 commit comments

Comments
 (0)