@@ -103,13 +103,17 @@ void Reset(int device_idx, common::Span<RowPartitioner::RowIndexT> ridx,
103
103
}
104
104
105
105
RowPartitioner::RowPartitioner (int device_idx, size_t num_rows)
106
- : device_idx_(device_idx), ridx_a_(num_rows), position_a_(num_rows) {
106
+ : device_idx_(device_idx), ridx_a_(num_rows), position_a_(num_rows),
107
+ ridx_b_ (num_rows), position_b_(num_rows) {
107
108
dh::safe_cuda (cudaSetDevice (device_idx_));
108
- Reset (device_idx, dh::ToSpan (ridx_a_), dh::ToSpan (position_a_));
109
+ ridx_ = dh::DoubleBuffer<RowIndexT>{&ridx_a_, &ridx_b_};
110
+ position_ = dh::DoubleBuffer<bst_node_t >{&position_a_, &position_b_};
111
+ ridx_segments_.emplace_back (Segment (0 , num_rows));
112
+
113
+ Reset (device_idx, ridx_.CurrentSpan (), position_.CurrentSpan ());
109
114
left_counts_.resize (256 );
110
115
thrust::fill (left_counts_.begin (), left_counts_.end (), 0 );
111
116
streams_.resize (2 );
112
- ridx_segments_.emplace_back (Segment (0 , num_rows));
113
117
for (auto & stream : streams_) {
114
118
dh::safe_cuda (cudaStreamCreate (&stream));
115
119
}
@@ -129,15 +133,15 @@ common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(
129
133
if (segment.Size () == 0 ) {
130
134
return common::Span<const RowPartitioner::RowIndexT>();
131
135
}
132
- return dh::ToSpan (ridx_a_ ).subspan (segment.begin , segment.Size ());
136
+ return ridx_. CurrentSpan ( ).subspan (segment.begin , segment.Size ());
133
137
}
134
138
135
139
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows () {
136
- return dh::ToSpan (ridx_a_ );
140
+ return ridx_. CurrentSpan ( );
137
141
}
138
142
139
143
common::Span<const bst_node_t > RowPartitioner::GetPosition () {
140
- return dh::ToSpan (position_a_ );
144
+ return position_. CurrentSpan ( );
141
145
}
142
146
std::vector<RowPartitioner::RowIndexT> RowPartitioner::GetRowsHost (
143
147
bst_node_t nidx) {
@@ -159,25 +163,23 @@ void RowPartitioner::SortPositionAndCopy(const Segment& segment,
159
163
bst_node_t right_nidx,
160
164
int64_t * d_left_count,
161
165
cudaStream_t stream) {
162
- dh::TemporaryArray<bst_node_t > position_temp (position_a_.size ());
163
- dh::TemporaryArray<RowIndexT> ridx_temp (ridx_a_.size ());
164
166
SortPosition (
165
167
// position_in
166
- common::Span<bst_node_t >(position_a_. data (). get () + segment.begin ,
168
+ common::Span<bst_node_t >(position_. Current () + segment.begin ,
167
169
segment.Size ()),
168
170
// position_out
169
- common::Span<bst_node_t >(position_temp. data (). get () + segment.begin ,
171
+ common::Span<bst_node_t >(position_. Other () + segment.begin ,
170
172
segment.Size ()),
171
173
// row index in
172
- common::Span<RowIndexT>(ridx_a_. data (). get () + segment.begin , segment.Size ()),
174
+ common::Span<RowIndexT>(ridx_. Current () + segment.begin , segment.Size ()),
173
175
// row index out
174
- common::Span<RowIndexT>(ridx_temp. data (). get () + segment.begin , segment.Size ()),
176
+ common::Span<RowIndexT>(ridx_. Other () + segment.begin , segment.Size ()),
175
177
left_nidx, right_nidx, d_left_count, stream);
176
178
// Copy back key/value
177
- const auto d_position_current = position_a_. data (). get () + segment.begin ;
178
- const auto d_position_other = position_temp. data (). get () + segment.begin ;
179
- const auto d_ridx_current = ridx_a_. data (). get () + segment.begin ;
180
- const auto d_ridx_other = ridx_temp. data (). get () + segment.begin ;
179
+ const auto d_position_current = position_. Current () + segment.begin ;
180
+ const auto d_position_other = position_. Other () + segment.begin ;
181
+ const auto d_ridx_current = ridx_. Current () + segment.begin ;
182
+ const auto d_ridx_other = ridx_. Other () + segment.begin ;
181
183
dh::LaunchN (device_idx_, segment.Size (), stream, [=] __device__ (size_t idx) {
182
184
d_position_current[idx] = d_position_other[idx];
183
185
d_ridx_current[idx] = d_ridx_other[idx];
0 commit comments