Skip to content

Commit caa4f32

Browse files
committed
Lint
1 parent caa0132 commit caa4f32

File tree

4 files changed

+83
-88
lines changed

4 files changed

+83
-88
lines changed

python-package/xgboost/core.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,13 +1021,12 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
10211021
print(type(data))
10221022
raise ValueError('Only cupy/cudf currently supported for DeviceDMatrix')
10231023

1024-
1025-
super().__init__(data,label=label,weight=weight, base_margin=base_margin,
1026-
missing=missing,
1027-
silent=silent,
1028-
feature_names=feature_names,
1029-
feature_types=feature_types,
1030-
nthread=nthread)
1024+
super().__init__(data, label=label, weight=weight, base_margin=base_margin,
1025+
missing=missing,
1026+
silent=silent,
1027+
feature_names=feature_names,
1028+
feature_types=feature_types,
1029+
nthread=nthread)
10311030

10321031
def _init_from_array_interface(self, data, missing, nthread):
10331032
"""Initialize DMatrix from cupy ndarray."""

src/common/compressed_iterator.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ static const int kPadding = 4; // Assign padding so we can read slightly off
3333

3434
// The number of bits required to represent a given unsigned range
3535
inline XGBOOST_DEVICE size_t SymbolBits(size_t num_symbols) {
36-
auto bits = std::ceil(log2(double(num_symbols)));
36+
auto bits = std::ceil(log2(static_cast<double>(num_symbols)));
3737
return std::max(static_cast<size_t>(bits), size_t(1));
3838
}
3939
} // namespace detail
@@ -53,8 +53,8 @@ class CompressedBufferWriter {
5353
size_t symbol_bits_;
5454

5555
public:
56-
XGBOOST_DEVICE explicit CompressedBufferWriter(size_t num_symbols):symbol_bits_(detail::SymbolBits(num_symbols)) {
57-
}
56+
XGBOOST_DEVICE explicit CompressedBufferWriter(size_t num_symbols)
57+
: symbol_bits_(detail::SymbolBits(num_symbols)) {}
5858

5959
/**
6060
* \fn static size_t CompressedBufferWriter::CalculateBufferSize(int

src/data/device_dmatrix.cu

Lines changed: 67 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,19 @@
44
* \brief Device-memory version of DMatrix.
55
*/
66

7+
#include <thrust/execution_policy.h>
8+
#include <thrust/iterator/discard_iterator.h>
9+
#include <thrust/iterator/transform_output_iterator.h>
710
#include <xgboost/base.h>
811
#include <xgboost/data.h>
9-
1012
#include <memory>
11-
#include <thrust/execution_policy.h>
12-
13-
#include <thrust/iterator/transform_output_iterator.h>
14-
#include <thrust/iterator/discard_iterator.h>
13+
#include <utility>
14+
#include "../common/hist_util.h"
15+
#include "../common/math.h"
1516
#include "adapter.h"
16-
#include "device_dmatrix.h"
1717
#include "device_adapter.cuh"
1818
#include "ellpack_page.cuh"
19-
#include "../common/hist_util.h"
20-
#include "../common/math.h"
19+
#include "device_dmatrix.h"
2120

2221
namespace xgboost {
2322
namespace data {
@@ -37,7 +36,7 @@ struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
3736
// Returns maximum row length
3837
template <typename AdapterBatchT>
3938
size_t GetRowCounts(const AdapterBatchT& batch, common::Span<size_t> offset,
40-
int device_idx, float missing) {
39+
int device_idx, float missing) {
4140
IsValidFunctor is_valid(missing);
4241
// Count elements per row
4342
dh::LaunchN(device_idx, batch.Size(), [=] __device__(size_t idx) {
@@ -51,23 +50,23 @@ size_t GetRowCounts(const AdapterBatchT& batch, common::Span<size_t> offset,
5150
dh::XGBCachingDeviceAllocator<char> alloc;
5251
size_t row_stride = thrust::reduce(
5352
thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()),
54-
thrust::device_pointer_cast(offset.data()) + offset.size(),
55-
size_t(0),
56-
thrust::maximum<size_t>());
53+
thrust::device_pointer_cast(offset.data()) + offset.size(), size_t(0),
54+
thrust::maximum<size_t>());
5755
return row_stride;
5856
}
5957

60-
template <typename AdapterBatchT>
61-
struct WriteCompressedEllpackFunctor
62-
{
58+
template <typename AdapterBatchT>
59+
struct WriteCompressedEllpackFunctor {
6360
WriteCompressedEllpackFunctor(common::CompressedByteT* buffer,
64-
const common::CompressedBufferWriter& writer, const AdapterBatchT& batch,const EllpackDeviceAccessor& accessor,const IsValidFunctor&is_valid)
65-
:
66-
d_buffer(buffer),
67-
writer(writer),
68-
batch(batch),accessor(accessor),is_valid(is_valid)
69-
{
70-
}
61+
const common::CompressedBufferWriter& writer,
62+
const AdapterBatchT& batch,
63+
EllpackDeviceAccessor accessor,
64+
const IsValidFunctor& is_valid)
65+
: d_buffer(buffer),
66+
writer(writer),
67+
batch(batch),
68+
accessor(std::move(accessor)),
69+
is_valid(is_valid) {}
7170

7271
common::CompressedByteT* d_buffer;
7372
common::CompressedBufferWriter writer;
@@ -76,55 +75,57 @@ struct WriteCompressedEllpackFunctor
7675
IsValidFunctor is_valid;
7776

7877
using Tuple = thrust::tuple<size_t, size_t, size_t>;
79-
__device__ size_t operator()(Tuple out)
80-
{
78+
__device__ size_t operator()(Tuple out) {
8179
auto e = batch.GetElement(out.get<2>());
8280
if (is_valid(e)) {
8381
// -1 because the scan is inclusive
84-
size_t output_position = accessor.row_stride * e.row_idx + out.get<1>() - 1;
82+
size_t output_position =
83+
accessor.row_stride * e.row_idx + out.get<1>() - 1;
8584
auto bin_idx = accessor.SearchBin(e.value, e.column_idx);
8685
writer.AtomicWriteSymbol(d_buffer, bin_idx, output_position);
8786
}
8887
return 0;
89-
9088
}
9189
};
9290

9391
// Here the data is already correctly ordered and simply needs to be compacted
9492
// to remove missing data
9593
template <typename AdapterBatchT>
96-
void CopyDataRowMajor(const AdapterBatchT& batch, EllpackPageImpl*dst,
97-
int device_idx, float missing,common::Span<size_t> row_counts) {
94+
void CopyDataRowMajor(const AdapterBatchT& batch, EllpackPageImpl* dst,
95+
int device_idx, float missing,
96+
common::Span<size_t> row_counts) {
9897
// Some witchcraft happens here
99-
// The goal is to copy valid elements out of the input to an ellpack matrix with a given row stride, using no extra working memory
100-
// Standard stream compaction needs to be modified to do this, so we manually define a segmented stream compaction via operators on an inclusive scan. The output of this inclusive scan is fed to a custom function which works out the correct output position
98+
// The goal is to copy valid elements out of the input to an ellpack matrix
99+
// with a given row stride, using no extra working memory. Standard stream
100+
// compaction needs to be modified to do this, so we manually define a
101+
// segmented stream compaction via operators on an inclusive scan. The output
102+
// of this inclusive scan is fed to a custom function which works out the
103+
// correct output position
101104
auto counting = thrust::make_counting_iterator(0llu);
102105
IsValidFunctor is_valid(missing);
103-
auto key_iter = dh::MakeTransformIterator<size_t >(counting,[=]__device__ (size_t idx)
104-
{
105-
return batch.GetElement(idx).row_idx;
106-
});
107-
auto value_iter = dh::MakeTransformIterator<size_t>(
106+
auto key_iter = dh::MakeTransformIterator<size_t>(
108107
counting,
109-
[=]__device__ (size_t idx) -> size_t
110-
{
111-
return is_valid(batch.GetElement(idx));
112-
});
108+
[=] __device__(size_t idx) { return batch.GetElement(idx).row_idx; });
109+
auto value_iter = dh::MakeTransformIterator<size_t>(
110+
counting, [=] __device__(size_t idx) -> size_t {
111+
return is_valid(batch.GetElement(idx));
112+
});
113113

114-
auto key_value_index_iter = thrust::make_zip_iterator(thrust::make_tuple(key_iter, value_iter, counting));
114+
auto key_value_index_iter = thrust::make_zip_iterator(
115+
thrust::make_tuple(key_iter, value_iter, counting));
115116

116117
// Tuple[0] = The row index of the input, used as a key to define segments
117118
// Tuple[1] = Scanned flags of valid elements for each row
118119
// Tuple[2] = The index in the input data
119-
using Tuple = thrust::tuple<size_t , size_t , size_t >;
120+
using Tuple = thrust::tuple<size_t, size_t, size_t>;
120121

121122
auto device_accessor = dst->GetDeviceAccessor(device_idx);
122123
common::CompressedBufferWriter writer(device_accessor.NumSymbols());
123124
auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();
124125

125126
// We redirect the scan output into this functor to do the actual writing
126-
WriteCompressedEllpackFunctor<AdapterBatchT> functor(d_compressed_buffer, writer,
127-
batch, device_accessor, is_valid);
127+
WriteCompressedEllpackFunctor<AdapterBatchT> functor(
128+
d_compressed_buffer, writer, batch, device_accessor, is_valid);
128129
thrust::discard_iterator<size_t> discard;
129130
thrust::transform_output_iterator<
130131
WriteCompressedEllpackFunctor<AdapterBatchT>, decltype(discard)>
@@ -153,8 +154,8 @@ void CopyDataColumnMajor(AdapterT* adapter, const AdapterBatchT& batch,
153154
dh::LaunchN(adapter->DeviceIdx(), batch.Size(), [=] __device__(size_t idx) {
154155
const auto& e = batch.GetElement(idx);
155156
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
156-
&d_column_sizes[e.column_idx]),
157-
static_cast<unsigned long long>(1)); // NOLINT
157+
&d_column_sizes[e.column_idx]),
158+
static_cast<unsigned long long>(1)); // NOLINT
158159
});
159160

160161
thrust::host_vector<size_t> host_column_sizes = column_sizes;
@@ -173,59 +174,57 @@ void CopyDataColumnMajor(AdapterT* adapter, const AdapterBatchT& batch,
173174
size_t end = begin + size;
174175
dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
175176
auto writer_non_const =
176-
writer; // For some reason this variable gets captured as const
177+
writer;  // The lambda is not mutable, so by-copy captures are const; copy it
177178
const auto& e = batch.GetElement(idx + begin);
178179
if (!is_valid(e)) return;
179-
size_t output_position = e.row_idx * row_stride + d_temp_row_ptr[e.row_idx];
180+
size_t output_position =
181+
e.row_idx * row_stride + d_temp_row_ptr[e.row_idx];
180182
auto bin_idx = device_accessor.SearchBin(e.value, e.column_idx);
181-
writer_non_const.AtomicWriteSymbol(d_compressed_buffer, bin_idx, output_position);
183+
writer_non_const.AtomicWriteSymbol(d_compressed_buffer, bin_idx,
184+
output_position);
182185
d_temp_row_ptr[e.row_idx] += 1;
183186
});
184187

185188
begin = end;
186189
}
187190
}
188191

189-
void WriteNullValues(EllpackPageImpl*dst,
190-
int device_idx, common::Span<size_t> row_counts)
191-
{
192-
// Write the null values
192+
void WriteNullValues(EllpackPageImpl* dst, int device_idx,
193+
common::Span<size_t> row_counts) {
194+
// Write the null values
193195
auto device_accessor = dst->GetDeviceAccessor(device_idx);
194196
common::CompressedBufferWriter writer(device_accessor.NumSymbols());
195197
auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();
196198
auto row_stride = dst->row_stride;
197-
dh::LaunchN(device_idx, row_stride * dst->n_rows, [=] __device__(size_t idx)
198-
{
199+
dh::LaunchN(device_idx, row_stride * dst->n_rows, [=] __device__(size_t idx) {
199200
auto writer_non_const =
200-
writer; // For some reason this variable gets captured as const
201+
writer;  // The lambda is not mutable, so by-copy captures are const; copy it
201202
size_t row_idx = idx / row_stride;
202203
size_t row_offset = idx % row_stride;
203-
if (row_offset >= row_counts[row_idx])
204-
{
204+
if (row_offset >= row_counts[row_idx]) {
205205
writer_non_const.AtomicWriteSymbol(d_compressed_buffer,
206206
device_accessor.NullValue(), idx);
207207
}
208208
});
209-
210-
}
209+
}
211210
// Does not currently support metainfo as no on-device data source contains this
212211
// Current implementation assumes a single batch. More batches can
213212
// be supported in future. Does not currently support inferring row/column size
214-
template <typename AdapterT>
213+
template <typename AdapterT>
215214
DeviceDMatrix::DeviceDMatrix(AdapterT* adapter, float missing, int nthread) {
216-
common::HistogramCuts cuts = common::AdapterDeviceSketch(adapter, 256, missing);
217-
auto & batch = adapter->Value();
215+
common::HistogramCuts cuts =
216+
common::AdapterDeviceSketch(adapter, 256, missing);
217+
auto& batch = adapter->Value();
218218
// Work out how many valid entries we have in each row
219-
dh::caching_device_vector<size_t> row_counts(adapter->NumRows() + 1,
220-
0);
221-
common::Span<size_t > row_counts_span( row_counts.data().get(),row_counts.size() );
219+
dh::caching_device_vector<size_t> row_counts(adapter->NumRows() + 1, 0);
220+
common::Span<size_t> row_counts_span(row_counts.data().get(),
221+
row_counts.size());
222222
size_t row_stride =
223223
GetRowCounts(batch, row_counts_span, adapter->DeviceIdx(), missing);
224224

225225
dh::XGBCachingDeviceAllocator<char> alloc;
226226
info.num_nonzero_ = thrust::reduce(thrust::cuda::par(alloc),
227-
row_counts.begin(),
228-
row_counts.end());
227+
row_counts.begin(), row_counts.end());
229228
info.num_col_ = adapter->NumColumns();
230229
info.num_row_ = adapter->NumRows();
231230
ellpack_page_.reset(new EllpackPage());
@@ -239,8 +238,7 @@ DeviceDMatrix::DeviceDMatrix(AdapterT* adapter, float missing, int nthread) {
239238
CopyDataColumnMajor(adapter, batch, ellpack_page_->Impl(), missing);
240239
}
241240

242-
WriteNullValues(ellpack_page_->Impl(), adapter->DeviceIdx(),
243-
row_counts_span);
241+
WriteNullValues(ellpack_page_->Impl(), adapter->DeviceIdx(), row_counts_span);
244242

245243
// Synchronise worker columns
246244
rabit::Allreduce<rabit::op::Max>(&info.num_col_, 1);

src/data/device_dmatrix.h

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
#include <memory>
1313

1414
#include "adapter.h"
15-
#include "simple_dmatrix.h"
1615
#include "simple_batch_iterator.h"
16+
#include "simple_dmatrix.h"
1717

1818
namespace xgboost {
1919
namespace data {
@@ -22,7 +22,7 @@ class DeviceDMatrix : public DMatrix {
2222
public:
2323
template <typename AdapterT>
2424
explicit DeviceDMatrix(AdapterT* adapter, float missing, int nthread);
25-
25+
2626
MetaInfo& Info() override { return info; }
2727

2828
const MetaInfo& Info() const override { return info; }
@@ -31,19 +31,17 @@ class DeviceDMatrix : public DMatrix {
3131

3232
bool EllpackExists() const override { return true; }
3333
bool SparsePageExists() const override { return false; }
34+
3435
private:
35-
BatchSet<SparsePage> GetRowBatches() override
36-
{
36+
BatchSet<SparsePage> GetRowBatches() override {
3737
LOG(FATAL) << "Not implemented.";
3838
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
3939
}
40-
BatchSet<CSCPage> GetColumnBatches()override
41-
{
40+
BatchSet<CSCPage> GetColumnBatches() override {
4241
LOG(FATAL) << "Not implemented.";
4342
return BatchSet<CSCPage>(BatchIterator<CSCPage>(nullptr));
4443
}
45-
BatchSet<SortedCSCPage> GetSortedColumnBatches()override
46-
{
44+
BatchSet<SortedCSCPage> GetSortedColumnBatches() override {
4745
LOG(FATAL) << "Not implemented.";
4846
return BatchSet<SortedCSCPage>(BatchIterator<SortedCSCPage>(nullptr));
4947
}
@@ -59,4 +57,4 @@ class DeviceDMatrix : public DMatrix {
5957
};
6058
} // namespace data
6159
} // namespace xgboost
62-
#endif // XGBOOST_DATA_DEVICE_DMATRIX_H_
60+
#endif // XGBOOST_DATA_DEVICE_DMATRIX_H_

0 commit comments

Comments
 (0)