
Commit 0a1038f

Rebase for iterative DMatrix.
1 parent: a612555

2 files changed: 26 additions & 56 deletions

src/common/quantile.cuh

Lines changed: 2 additions & 3 deletions
@@ -121,10 +121,9 @@ class SketchContainer {
 
   Span<OffsetT const> ColumnsPtr() const { return this->columns_ptr_.ConstDeviceSpan(); }
 
-  // Prevent copying/assigning/moving this as its internals can't be
-  // assigned/copied/moved
+  SketchContainer(SketchContainer&&) = default;
+  SketchContainer& operator=(SketchContainer&&) = default;
   SketchContainer(const SketchContainer&) = delete;
-  SketchContainer(const SketchContainer&&) = delete;
   SketchContainer& operator=(const SketchContainer&) = delete;
   SketchContainer& operator=(const SketchContainer&&) = delete;
 };
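The constructor block now follows the movable-but-non-copyable pattern: the defaulted move operations are what allow a std::vector<common::SketchContainer> (and the sketch_containers.emplace_back(...) in the second file) to work, while the deleted copy operations still guard against silently duplicating the container's device memory. Below is a minimal standalone sketch of the same pattern, with a hypothetical DeviceBuffer standing in for the container's internals:

#include <utility>
#include <vector>

// Hypothetical stand-in for the device-side state a real container owns.
struct DeviceBuffer {};

class Container {
  DeviceBuffer buf_;

 public:
  Container() = default;
  Container(Container&&) = default;             // movable, so it can live in a vector
  Container& operator=(Container&&) = default;  // and be move-assigned
  Container(const Container&) = delete;         // but copying is forbidden:
  Container& operator=(const Container&) = delete;  // no accidental duplication

  // Real members (device buffers, sketches, ...) would go here.
};

int main() {
  std::vector<Container> batches;
  batches.emplace_back();                      // ok: constructed in place
  Container last = std::move(batches.back());  // ok: uses the move constructor
  // Container copy = last;                    // would not compile: copies are deleted
  (void)last;
  return 0;
}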

src/data/iterative_device_dmatrix.cu

Lines changed: 24 additions & 53 deletions
@@ -57,80 +57,51 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
   size_t nnz = 0;
   // Sketch for all batches.
   iter.Reset();
-  common::HistogramCuts cuts;
-  common::DenseCuts dense_cuts(&cuts);
 
   std::vector<common::SketchContainer> sketch_containers;
   size_t batches = 0;
   size_t accumulated_rows = 0;
   bst_feature_t cols = 0;
+  int32_t device = -1;
   while (iter.Next()) {
-    auto device = proxy->DeviceIdx();
+    device = proxy->DeviceIdx();
     dh::safe_cuda(cudaSetDevice(device));
     if (cols == 0) {
       cols = num_cols();
     } else {
       CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";
     }
-    sketch_containers.emplace_back(batch_param_.max_bin, num_cols(), num_rows());
+    sketch_containers.emplace_back(batch_param_.max_bin, num_cols(), num_rows(), device);
     auto* p_sketch = &sketch_containers.back();
-    if (proxy->Info().weights_.Size() != 0) {
     proxy->Info().weights_.SetDevice(device);
     Dispatch(proxy, [&](auto const &value) {
-      common::AdapterDeviceSketchWeighted(value, batch_param_.max_bin,
-                                          proxy->Info(),
-                                          missing, device, p_sketch);
-    });
-    } else {
-      Dispatch(proxy, [&](auto const &value) {
-        common::AdapterDeviceSketch(value, batch_param_.max_bin, missing,
-                                    device, p_sketch);
-      });
-    }
+      common::AdapterDeviceSketchWeighted(value, batch_param_.max_bin,
+                                          proxy->Info(), missing, p_sketch);
+    });
 
-    auto batch_rows = num_rows();
-    accumulated_rows += batch_rows;
-    dh::caching_device_vector<size_t> row_counts(batch_rows + 1, 0);
-    common::Span<size_t> row_counts_span(row_counts.data().get(),
-                                         row_counts.size());
-    row_stride =
-        std::max(row_stride, Dispatch(proxy, [=](auto const& value) {
-          return GetRowCounts(value, row_counts_span, device, missing);
-        }));
-    nnz += thrust::reduce(thrust::cuda::par(alloc),
-                          row_counts.begin(), row_counts.end());
-    batches++;
+    auto batch_rows = num_rows();
+    accumulated_rows += batch_rows;
+    dh::caching_device_vector<size_t> row_counts(batch_rows + 1, 0);
+    common::Span<size_t> row_counts_span(row_counts.data().get(),
+                                         row_counts.size());
+    row_stride = std::max(row_stride, Dispatch(proxy, [=](auto const &value) {
+                            return GetRowCounts(value, row_counts_span,
+                                                device, missing);
+                          }));
+    nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(),
+                          row_counts.end());
+    batches++;
   }
 
-  // Merging multiple batches for each column
-  std::vector<common::WQSketch::SummaryContainer> summary_array(cols);
-  size_t intermediate_num_cuts = std::min(
-      accumulated_rows, static_cast<size_t>(batch_param_.max_bin *
-                                            common::SketchContainer::kFactor));
-  size_t nbytes =
-      common::WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts);
-#pragma omp parallel for num_threads(nthread) if (nthread > 0)
-  for (omp_ulong c = 0; c < cols; ++c) {
-    for (auto& sketch_batch : sketch_containers) {
-      common::WQSketch::SummaryContainer summary;
-      sketch_batch.sketches_.at(c).GetSummary(&summary);
-      sketch_batch.sketches_.at(c).Init(0, 1);
-      summary_array.at(c).Reduce(summary, nbytes);
-    }
+  common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, device);
+  for (auto const& sketch: sketch_containers) {
+    final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());
   }
   sketch_containers.clear();
+  sketch_containers.shrink_to_fit();
 
-  // Build the final summary.
-  std::vector<common::WQSketch> sketches(cols);
-#pragma omp parallel for num_threads(nthread) if (nthread > 0)
-  for (omp_ulong c = 0; c < cols; ++c) {
-    sketches.at(c).Init(
-        accumulated_rows,
-        1.0 / (common::SketchContainer::kFactor * batch_param_.max_bin));
-    sketches.at(c).PushSummary(summary_array.at(c));
-  }
-  dense_cuts.Init(&sketches, batch_param_.max_bin, accumulated_rows);
-  summary_array.clear();
+  common::HistogramCuts cuts;
+  final_sketch.MakeCuts(&cuts);
 
   this->info_.num_col_ = cols;
   this->info_.num_row_ = accumulated_rows;
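Taken together, the new flow is: build one on-device SketchContainer per batch inside the while (iter.Next()) loop, fold the batch containers into final_sketch with Merge(), then call MakeCuts() once, replacing the earlier per-column WQSketch::SummaryContainer reduction and DenseCuts build on the host. The toy program below mimics that control flow with a deliberately simplified single-column sketch; ToySketch, kMaxSize, and the evenly-spaced pruning are illustrative stand-ins, not xgboost's actual GK-summary machinery:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <utility>
#include <vector>

// Deliberately simplified stand-in for common::SketchContainer: keeps a
// bounded, sorted sample of the values seen so far for a single column.
class ToySketch {
  std::vector<float> samples_;
  static constexpr std::size_t kMaxSize = 64;  // crude analogue of the sketch size bound

  void Prune() {
    // Keep at most kMaxSize evenly spaced samples.
    if (samples_.size() <= kMaxSize) return;
    std::vector<float> kept;
    for (std::size_t i = 0; i < kMaxSize; ++i) {
      kept.push_back(samples_[i * samples_.size() / kMaxSize]);
    }
    samples_ = std::move(kept);
  }

 public:
  ToySketch() = default;
  ToySketch(ToySketch&&) = default;      // movable, like SketchContainer
  ToySketch(const ToySketch&) = delete;  // and likewise non-copyable

  // Analogue of the per-batch sketching call filling a fresh container.
  void Push(std::vector<float> batch) {
    samples_.insert(samples_.end(), batch.begin(), batch.end());
    std::sort(samples_.begin(), samples_.end());
    Prune();
  }

  // Analogue of final_sketch.Merge(...): fold another batch's sketch in.
  void Merge(const ToySketch& other) {
    std::vector<float> merged;
    std::merge(samples_.begin(), samples_.end(), other.samples_.begin(),
               other.samples_.end(), std::back_inserter(merged));
    samples_ = std::move(merged);
    Prune();
  }

  // Analogue of MakeCuts(): derive max_bin quantile boundaries at the end.
  std::vector<float> MakeCuts(std::size_t max_bin) const {
    std::vector<float> cuts;
    for (std::size_t i = 1; i <= max_bin; ++i) {
      cuts.push_back(samples_[std::min(samples_.size() - 1,
                                       i * samples_.size() / max_bin)]);
    }
    return cuts;
  }
};

int main() {
  // One sketch per batch, as in the while (iter.Next()) loop.
  std::vector<ToySketch> sketch_containers;
  for (int b = 0; b < 3; ++b) {
    std::vector<float> batch;
    for (int i = 0; i < 1000; ++i) batch.push_back(static_cast<float>(b * 1000 + i));
    sketch_containers.emplace_back();
    sketch_containers.back().Push(std::move(batch));
  }

  // Merge all per-batch sketches into one, then derive the cuts once.
  ToySketch final_sketch;
  for (auto const& sketch : sketch_containers) final_sketch.Merge(sketch);
  sketch_containers.clear();  // per-batch memory is no longer needed

  for (float c : final_sketch.MakeCuts(8)) std::cout << c << " ";
  std::cout << "\n";
  return 0;
}

Keeping one container per batch and merging once at the end is also what makes the added sketch_containers.clear() / shrink_to_fit() pair worthwhile: the per-batch device memory can be released as soon as the merged final_sketch exists.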
