@@ -57,80 +57,51 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missing
   size_t nnz = 0;
   // Sketch for all batches.
   iter.Reset();
-  common::HistogramCuts cuts;
-  common::DenseCuts dense_cuts(&cuts);
 
   std::vector<common::SketchContainer> sketch_containers;
   size_t batches = 0;
   size_t accumulated_rows = 0;
   bst_feature_t cols = 0;
+  int32_t device = -1;
   while (iter.Next()) {
-    auto device = proxy->DeviceIdx();
+    device = proxy->DeviceIdx();
     dh::safe_cuda(cudaSetDevice(device));
     if (cols == 0) {
       cols = num_cols();
     } else {
       CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";
     }
-    sketch_containers.emplace_back(batch_param_.max_bin, num_cols(), num_rows());
+    sketch_containers.emplace_back(batch_param_.max_bin, num_cols(), num_rows(), device);
     auto* p_sketch = &sketch_containers.back();
-    if (proxy->Info().weights_.Size() != 0) {
     proxy->Info().weights_.SetDevice(device);
     Dispatch(proxy, [&](auto const &value) {
-      common::AdapterDeviceSketchWeighted(value, batch_param_.max_bin,
-                                          proxy->Info(),
-                                          missing, device, p_sketch);
-      });
-    } else {
-      Dispatch(proxy, [&](auto const &value) {
-        common::AdapterDeviceSketch(value, batch_param_.max_bin, missing,
-                                    device, p_sketch);
-      });
-    }
+        common::AdapterDeviceSketchWeighted(value, batch_param_.max_bin,
+                                            proxy->Info(), missing, p_sketch);
+      });
 
-    auto batch_rows = num_rows();
-    accumulated_rows += batch_rows;
-    dh::caching_device_vector<size_t> row_counts(batch_rows + 1, 0);
-    common::Span<size_t> row_counts_span(row_counts.data().get(),
-                                         row_counts.size());
-    row_stride =
-        std::max(row_stride, Dispatch(proxy, [=](auto const& value) {
-          return GetRowCounts(value, row_counts_span, device, missing);
-        }));
-    nnz += thrust::reduce(thrust::cuda::par(alloc),
-                          row_counts.begin(), row_counts.end());
-    batches++;
+    auto batch_rows = num_rows();
+    accumulated_rows += batch_rows;
+    dh::caching_device_vector<size_t> row_counts(batch_rows + 1, 0);
+    common::Span<size_t> row_counts_span(row_counts.data().get(),
+                                         row_counts.size());
+    row_stride = std::max(row_stride, Dispatch(proxy, [=](auto const &value) {
+                            return GetRowCounts(value, row_counts_span,
+                                                device, missing);
+                          }));
+    nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(),
+                          row_counts.end());
+    batches++;
   }
 
-  // Merging multiple batches for each column
-  std::vector<common::WQSketch::SummaryContainer> summary_array(cols);
-  size_t intermediate_num_cuts = std::min(
-      accumulated_rows, static_cast<size_t>(batch_param_.max_bin *
-                                            common::SketchContainer::kFactor));
-  size_t nbytes =
-      common::WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts);
-#pragma omp parallel for num_threads(nthread) if (nthread > 0)
-  for (omp_ulong c = 0; c < cols; ++c) {
-    for (auto& sketch_batch : sketch_containers) {
-      common::WQSketch::SummaryContainer summary;
-      sketch_batch.sketches_.at(c).GetSummary(&summary);
-      sketch_batch.sketches_.at(c).Init(0, 1);
-      summary_array.at(c).Reduce(summary, nbytes);
-    }
+  common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, device);
+  for (auto const& sketch : sketch_containers) {
+    final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());
   }
   sketch_containers.clear();
+  sketch_containers.shrink_to_fit();
 
-  // Build the final summary.
-  std::vector<common::WQSketch> sketches(cols);
-#pragma omp parallel for num_threads(nthread) if (nthread > 0)
-  for (omp_ulong c = 0; c < cols; ++c) {
-    sketches.at(c).Init(
-        accumulated_rows,
-        1.0 / (common::SketchContainer::kFactor * batch_param_.max_bin));
-    sketches.at(c).PushSummary(summary_array.at(c));
-  }
-  dense_cuts.Init(&sketches, batch_param_.max_bin, accumulated_rows);
-  summary_array.clear();
+  common::HistogramCuts cuts;
+  final_sketch.MakeCuts(&cuts);
 
   this->info_.num_col_ = cols;
   this->info_.num_row_ = accumulated_rows;