@@ -4,20 +4,19 @@
  * \brief Device-memory version of DMatrix.
  */

+#include <thrust/execution_policy.h>
+#include <thrust/iterator/discard_iterator.h>
+#include <thrust/iterator/transform_output_iterator.h>
 #include <xgboost/base.h>
 #include <xgboost/data.h>
-
 #include <memory>
-#include <thrust/execution_policy.h>
-
-#include <thrust/iterator/transform_output_iterator.h>
-#include <thrust/iterator/discard_iterator.h>
+#include <utility>
+#include "../common/hist_util.h"
+#include "../common/math.h"
 #include "adapter.h"
-#include "device_dmatrix.h"
 #include "device_adapter.cuh"
 #include "ellpack_page.cuh"
-#include "../common/hist_util.h"
-#include "../common/math.h"
+#include "device_dmatrix.h"

 namespace xgboost {
 namespace data {
@@ -37,7 +36,7 @@ struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
 // Returns maximum row length
 template <typename AdapterBatchT>
 size_t GetRowCounts(const AdapterBatchT& batch, common::Span<size_t> offset,
-    int device_idx, float missing) {
+                    int device_idx, float missing) {
   IsValidFunctor is_valid(missing);
   // Count elements per row
   dh::LaunchN(device_idx, batch.Size(), [=] __device__(size_t idx) {
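`GetRowCounts`, continued in the next hunk, is a two-step Thrust pattern: a kernel counts the valid entries of each row, then a max-reduction over those counts gives the ELLPACK row stride. A minimal self-contained sketch of that pattern, assuming a hypothetical flat COO-style input (parallel `values`/`row_idx` arrays) in place of the adapter batch, and plain CUDA/Thrust in place of the `dh::` helpers:

```cuda
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/reduce.h>

#include <cstdio>
#include <vector>

// One thread per element; atomics are required because many elements share a
// row. NaN and the user-supplied sentinel both count as missing.
__global__ void CountValid(const float* values, const int* row_idx, size_t n,
                           float missing, unsigned long long* counts) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  if (!isnan(values[i]) && values[i] != missing) {
    atomicAdd(&counts[row_idx[i]], 1ULL);
  }
}

int main() {
  // Three rows; -1 marks missing entries.
  std::vector<float> h_values = {1.0f, -1.0f, 2.0f, 3.0f, -1.0f, 4.0f};
  std::vector<int> h_rows = {0, 0, 0, 1, 1, 2};
  thrust::device_vector<float> values = h_values;
  thrust::device_vector<int> rows = h_rows;
  thrust::device_vector<unsigned long long> counts(3, 0);

  CountValid<<<1, 256>>>(values.data().get(), rows.data().get(), values.size(),
                         -1.0f, counts.data().get());

  // The row stride is the length of the longest row: a max-reduction.
  unsigned long long row_stride =
      thrust::reduce(counts.begin(), counts.end(), 0ULL,
                     thrust::maximum<unsigned long long>());
  printf("row_stride = %llu\n", row_stride);  // prints 2
}
```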
@@ -51,23 +50,23 @@ size_t GetRowCounts(const AdapterBatchT& batch, common::Span<size_t> offset,
   dh::XGBCachingDeviceAllocator<char> alloc;
   size_t row_stride = thrust::reduce(
       thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()),
-      thrust::device_pointer_cast(offset.data()) + offset.size(),
-      size_t(0),
-      thrust::maximum<size_t>());
+      thrust::device_pointer_cast(offset.data()) + offset.size(), size_t(0),
+      thrust::maximum<size_t>());
   return row_stride;
 }

-template <typename AdapterBatchT>
-struct WriteCompressedEllpackFunctor
-{
+template <typename AdapterBatchT>
+struct WriteCompressedEllpackFunctor {
   WriteCompressedEllpackFunctor(common::CompressedByteT* buffer,
-      const common::CompressedBufferWriter& writer, const AdapterBatchT& batch,const EllpackDeviceAccessor& accessor,const IsValidFunctor&is_valid)
-      :
-      d_buffer(buffer),
-      writer(writer),
-      batch(batch),accessor(accessor),is_valid(is_valid)
-      {
-      }
+                                const common::CompressedBufferWriter& writer,
+                                const AdapterBatchT& batch,
+                                EllpackDeviceAccessor accessor,
+                                const IsValidFunctor& is_valid)
+      : d_buffer(buffer),
+        writer(writer),
+        batch(batch),
+        accessor(std::move(accessor)),
+        is_valid(is_valid) {}

   common::CompressedByteT* d_buffer;
   common::CompressedBufferWriter writer;
@@ -76,55 +75,57 @@ struct WriteCompressedEllpackFunctor
   IsValidFunctor is_valid;

   using Tuple = thrust::tuple<size_t, size_t, size_t>;
-  __device__ size_t operator()(Tuple out)
-  {
+  __device__ size_t operator()(Tuple out) {
     auto e = batch.GetElement(out.get<2>());
     if (is_valid(e)) {
       // -1 because the scan is inclusive
-      size_t output_position = accessor.row_stride * e.row_idx + out.get<1>() - 1;
+      size_t output_position =
+          accessor.row_stride * e.row_idx + out.get<1>() - 1;
       auto bin_idx = accessor.SearchBin(e.value, e.column_idx);
       writer.AtomicWriteSymbol(d_buffer, bin_idx, output_position);
     }
     return 0;
-
   }
 };

 // Here the data is already correctly ordered and simply needs to be compacted
 // to remove missing data
 template <typename AdapterBatchT>
-void CopyDataRowMajor(const AdapterBatchT& batch, EllpackPageImpl*dst,
-    int device_idx, float missing,common::Span<size_t> row_counts) {
+void CopyDataRowMajor(const AdapterBatchT& batch, EllpackPageImpl* dst,
+                      int device_idx, float missing,
+                      common::Span<size_t> row_counts) {
   // Some witchcraft happens here
-  // The goal is to copy valid elements out of the input to an ellpack matrix with a given row stride, using no extra working memory
-  // Standard stream compaction needs to be modified to do this, so we manually define a segmented stream compaction via operators on an inclusive scan. The output of this inclusive scan is fed to a custom function which works out the correct output position
+  // The goal is to copy valid elements out of the input to an ellpack matrix
+  // with a given row stride, using no extra working memory. Standard stream
+  // compaction needs to be modified to do this, so we manually define a
+  // segmented stream compaction via operators on an inclusive scan. The output
+  // of this inclusive scan is fed to a custom function which works out the
+  // correct output position.
   auto counting = thrust::make_counting_iterator(0llu);
   IsValidFunctor is_valid(missing);
-  auto key_iter = dh::MakeTransformIterator<size_t>(counting,[=]__device__(size_t idx)
-  {
-    return batch.GetElement(idx).row_idx;
-  });
-  auto value_iter = dh::MakeTransformIterator<size_t>(
+  auto key_iter = dh::MakeTransformIterator<size_t>(
       counting,
-      [=]__device__(size_t idx) -> size_t
-      {
-        return is_valid(batch.GetElement(idx));
-      });
+      [=] __device__(size_t idx) { return batch.GetElement(idx).row_idx; });
+  auto value_iter = dh::MakeTransformIterator<size_t>(
+      counting, [=] __device__(size_t idx) -> size_t {
+        return is_valid(batch.GetElement(idx));
+      });

-  auto key_value_index_iter = thrust::make_zip_iterator(thrust::make_tuple(key_iter, value_iter, counting));
+  auto key_value_index_iter = thrust::make_zip_iterator(
+      thrust::make_tuple(key_iter, value_iter, counting));

   // Tuple[0] = The row index of the input, used as a key to define segments
   // Tuple[1] = Scanned flags of valid elements for each row
   // Tuple[2] = The index in the input data
-  using Tuple = thrust::tuple<size_t, size_t, size_t>;
+  using Tuple = thrust::tuple<size_t, size_t, size_t>;

   auto device_accessor = dst->GetDeviceAccessor(device_idx);
   common::CompressedBufferWriter writer(device_accessor.NumSymbols());
   auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();

   // We redirect the scan output into this functor to do the actual writing
-  WriteCompressedEllpackFunctor<AdapterBatchT> functor(d_compressed_buffer, writer,
-                                                       batch, device_accessor, is_valid);
+  WriteCompressedEllpackFunctor<AdapterBatchT> functor(
+      d_compressed_buffer, writer, batch, device_accessor, is_valid);
   thrust::discard_iterator<size_t> discard;
   thrust::transform_output_iterator<
       WriteCompressedEllpackFunctor<AdapterBatchT>, decltype(discard)>
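The "witchcraft" comment above is the heart of this file: a segmented stream compaction expressed as a single inclusive scan whose output iterator does the writing, so no intermediate compacted buffer is ever allocated. The scan operator that restarts the running count at each row boundary is defined elsewhere in the file and is not part of this diff, so the self-contained sketch below supplies an equivalent one; `ValidFlag`, `SegmentedSum`, and `WriteCompacted` are illustrative stand-ins, and a plain `float` buffer replaces the compressed ELLPACK buffer:

```cuda
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/scan.h>
#include <thrust/tuple.h>

#include <cstdio>
#include <vector>

// (row key, running count of valid elements, input index)
using Tuple = thrust::tuple<size_t, size_t, size_t>;

// 0/1 flag: is the element at input position i valid? (negative == missing)
struct ValidFlag {
  const float* values;
  __host__ __device__ size_t operator()(size_t i) const {
    return values[i] >= 0.0f;
  }
};

// Scan operator for the segmented count: restart whenever the row key
// changes, otherwise accumulate the flags. Key and index pass through from b.
struct SegmentedSum {
  __host__ __device__ Tuple operator()(Tuple a, Tuple b) const {
    size_t count = thrust::get<0>(a) == thrust::get<0>(b)
                       ? thrust::get<1>(a) + thrust::get<1>(b)
                       : thrust::get<1>(b);
    return thrust::make_tuple(thrust::get<0>(b), count, thrust::get<2>(b));
  }
};

// Consumes each scanned tuple: a valid element lands at
// row * row_stride + (valid elements so far in its row) - 1.
struct WriteCompacted {
  float* out;
  const float* values;
  const size_t* row_of;
  size_t row_stride;
  __host__ __device__ size_t operator()(Tuple t) const {
    size_t i = thrust::get<2>(t);
    if (values[i] >= 0.0f) {
      out[row_of[i] * row_stride + thrust::get<1>(t) - 1] = values[i];
    }
    return 0;  // thrown away by the discard_iterator
  }
};

int main() {
  // Two rows, row stride 2; -1 marks the missing values to compact away.
  std::vector<float> h_values = {5.0f, -1.0f, 7.0f, -1.0f, 9.0f, -1.0f};
  std::vector<size_t> h_rows = {0, 0, 0, 1, 1, 1};
  thrust::device_vector<float> values = h_values;
  thrust::device_vector<size_t> rows = h_rows;
  thrust::device_vector<float> out(4, 0.0f);

  auto counting = thrust::make_counting_iterator<size_t>(0);
  auto flags = thrust::make_transform_iterator(counting,
                                               ValidFlag{values.data().get()});
  auto zip = thrust::make_zip_iterator(
      thrust::make_tuple(rows.begin(), flags, counting));

  thrust::discard_iterator<size_t> discard;
  WriteCompacted writer{out.data().get(), values.data().get(),
                        rows.data().get(), /*row_stride=*/2};
  auto out_it = thrust::make_transform_output_iterator(discard, writer);

  // The scan drives everything: each assignment through out_it invokes the
  // writer on a scanned tuple, so no intermediate buffer is needed.
  thrust::inclusive_scan(zip, zip + values.size(), out_it, SegmentedSum{});

  thrust::host_vector<float> h_out = out;
  for (float v : h_out) printf("%g ", v);  // prints: 5 7 9 0
  printf("\n");
}
```

The `return 0;` / `discard_iterator` pairing exists only because a scan has to write something; all the real work happens as a side effect inside the functor, which is the same design `WriteCompressedEllpackFunctor` uses above.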
@@ -153,8 +154,8 @@ void CopyDataColumnMajor(AdapterT* adapter, const AdapterBatchT& batch,
   dh::LaunchN(adapter->DeviceIdx(), batch.Size(), [=] __device__(size_t idx) {
     const auto& e = batch.GetElement(idx);
     atomicAdd(reinterpret_cast<unsigned long long*>(  // NOLINT
-        &d_column_sizes[e.column_idx]),
-        static_cast<unsigned long long>(1));  // NOLINT
+                  &d_column_sizes[e.column_idx]),
+              static_cast<unsigned long long>(1));  // NOLINT
   });

   thrust::host_vector<size_t> host_column_sizes = column_sizes;
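The `NOLINT` comments in the hunk above mark a `reinterpret_cast` workaround: CUDA's `atomicAdd` has no `size_t` overload, so the `size_t` column counters are reinterpreted as `unsigned long long`, which has the same width on the 64-bit platforms CUDA targets. A toy version of the per-column counting that sidesteps the cast by storing `unsigned long long` directly:

```cuda
#include <cuda_runtime.h>
#include <thrust/device_vector.h>

#include <cstdio>
#include <vector>

// One thread per element; each increments its column's counter.
__global__ void CountColumns(const int* col_idx, size_t n,
                             unsigned long long* sizes) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) atomicAdd(&sizes[col_idx[i]], 1ULL);
}

int main() {
  std::vector<int> h_cols = {0, 2, 2, 1, 0, 2};
  thrust::device_vector<int> cols = h_cols;
  thrust::device_vector<unsigned long long> sizes(3, 0);

  CountColumns<<<1, 256>>>(cols.data().get(), cols.size(), sizes.data().get());
  cudaDeviceSynchronize();

  for (size_t c = 0; c < sizes.size(); ++c) {
    printf("column %zu: %llu\n", c, (unsigned long long)sizes[c]);
  }  // prints 2, 1, 3
}
```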
@@ -173,59 +174,57 @@ void CopyDataColumnMajor(AdapterT* adapter, const AdapterBatchT& batch,
     size_t end = begin + size;
     dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
       auto writer_non_const =
-          writer; // For some reason this variable gets captured as const
+          writer;  // For some reason this variable gets captured as const
       const auto& e = batch.GetElement(idx + begin);
       if (!is_valid(e)) return;
-      size_t output_position = e.row_idx * row_stride + d_temp_row_ptr[e.row_idx];
+      size_t output_position =
+          e.row_idx * row_stride + d_temp_row_ptr[e.row_idx];
       auto bin_idx = device_accessor.SearchBin(e.value, e.column_idx);
-      writer_non_const.AtomicWriteSymbol(d_compressed_buffer, bin_idx, output_position);
+      writer_non_const.AtomicWriteSymbol(d_compressed_buffer, bin_idx,
+                                         output_position);
       d_temp_row_ptr[e.row_idx] += 1;
     });

     begin = end;
   }
 }

-void WriteNullValues(EllpackPageImpl*dst,
-    int device_idx, common::Span<size_t> row_counts)
-{
-    // Write the null values
+void WriteNullValues(EllpackPageImpl* dst, int device_idx,
+                     common::Span<size_t> row_counts) {
+  // Write the null values
   auto device_accessor = dst->GetDeviceAccessor(device_idx);
   common::CompressedBufferWriter writer(device_accessor.NumSymbols());
   auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();
   auto row_stride = dst->row_stride;
-  dh::LaunchN(device_idx, row_stride * dst->n_rows, [=] __device__(size_t idx)
-  {
+  dh::LaunchN(device_idx, row_stride * dst->n_rows, [=] __device__(size_t idx) {
     auto writer_non_const =
-        writer; // For some reason this variable gets captured as const
+        writer;  // For some reason this variable gets captured as const
     size_t row_idx = idx / row_stride;
     size_t row_offset = idx % row_stride;
-    if (row_offset >= row_counts[row_idx])
-    {
+    if (row_offset >= row_counts[row_idx]) {
       writer_non_const.AtomicWriteSymbol(d_compressed_buffer,
                                          device_accessor.NullValue(), idx);
     }
   });
-
-}
+}
 // Does not currently support metainfo as no on-device data source contains this
 // Current implementation assumes a single batch. More batches can
 // be supported in future. Does not currently support inferring row/column size
-template <typename AdapterT>
+template <typename AdapterT>
 DeviceDMatrix::DeviceDMatrix(AdapterT* adapter, float missing, int nthread) {
-  common::HistogramCuts cuts = common::AdapterDeviceSketch(adapter, 256, missing);
-  auto &batch = adapter->Value();
+  common::HistogramCuts cuts =
+      common::AdapterDeviceSketch(adapter, 256, missing);
+  auto& batch = adapter->Value();
   // Work out how many valid entries we have in each row
-  dh::caching_device_vector<size_t> row_counts(adapter->NumRows() + 1,
-                                               0);
-  common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
+  dh::caching_device_vector<size_t> row_counts(adapter->NumRows() + 1, 0);
+  common::Span<size_t> row_counts_span(row_counts.data().get(),
+                                       row_counts.size());
   size_t row_stride =
       GetRowCounts(batch, row_counts_span, adapter->DeviceIdx(), missing);

   dh::XGBCachingDeviceAllocator<char> alloc;
   info.num_nonzero_ = thrust::reduce(thrust::cuda::par(alloc),
-                                     row_counts.begin(),
-                                     row_counts.end());
+                                     row_counts.begin(), row_counts.end());
   info.num_col_ = adapter->NumColumns();
   info.num_row_ = adapter->NumRows();
   ellpack_page_.reset(new EllpackPage());
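`WriteNullValues`, reformatted in the hunk above, pads every slot past a row's valid count with the null bin so that downstream kernels can treat each row as exactly `row_stride` entries wide. A toy sketch of the same indexing, assuming a plain `int` buffer in place of the compressed `gidx` buffer and `-1` in place of `device_accessor.NullValue()`:

```cuda
#include <cuda_runtime.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

#include <cstdio>
#include <vector>

// One thread per ELLPACK slot: slot idx belongs to row idx / row_stride at
// offset idx % row_stride, and offsets past the row's valid count get the
// null symbol, mirroring the bounds test in WriteNullValues.
__global__ void PadNulls(int* gidx, const size_t* row_counts, size_t row_stride,
                         size_t n_rows, int null_symbol) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= row_stride * n_rows) return;
  if (idx % row_stride >= row_counts[idx / row_stride]) {
    gidx[idx] = null_symbol;
  }
}

int main() {
  // Two rows with 3 and 1 valid entries; row_stride = 3 (the longest row).
  std::vector<size_t> h_counts = {3, 1};
  thrust::device_vector<size_t> row_counts = h_counts;
  thrust::device_vector<int> gidx(6, 0);  // pretend bins were already written

  PadNulls<<<1, 256>>>(gidx.data().get(), row_counts.data().get(),
                       /*row_stride=*/3, /*n_rows=*/2, /*null_symbol=*/-1);
  cudaDeviceSynchronize();

  thrust::host_vector<int> h_gidx = gidx;
  for (int v : h_gidx) printf("%d ", v);  // prints: 0 0 0 0 -1 -1
  printf("\n");
}
```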
@@ -239,8 +238,7 @@ DeviceDMatrix::DeviceDMatrix(AdapterT* adapter, float missing, int nthread) {
     CopyDataColumnMajor(adapter, batch, ellpack_page_->Impl(), missing);
   }

-  WriteNullValues(ellpack_page_->Impl(), adapter->DeviceIdx(),
-                  row_counts_span);
+  WriteNullValues(ellpack_page_->Impl(), adapter->DeviceIdx(), row_counts_span);

   // Synchronise worker columns
   rabit::Allreduce<rabit::op::Max>(&info.num_col_, 1);