Skip to content

Commit 8b993ff

Browse files
committed
Forbid pointer to bool cast.
1 parent 076a788 commit 8b993ff

File tree

4 files changed

+12
-10
lines changed

4 files changed

+12
-10
lines changed

src/tree/gpu_hist/gradient_based_sampler.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -318,9 +318,9 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c
318318
return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
319319
}
320320

321-
GradientBasedSampler::GradientBasedSampler(Context const* /*ctx*/, bool is_external_memory,
322-
size_t n_rows, const BatchParam& batch_param,
323-
float subsample, int sampling_method) {
321+
GradientBasedSampler::GradientBasedSampler(Context const* /*ctx*/, size_t n_rows,
322+
const BatchParam& batch_param, float subsample,
323+
int sampling_method, bool is_external_memory) {
324324
// The ctx is kept here for future development of stream-based operations.
325325
monitor_.Init("gradient_based_sampler");
326326

src/tree/gpu_hist/gradient_based_sampler.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@ class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
122122
*/
123123
class GradientBasedSampler {
124124
public:
125-
GradientBasedSampler(Context const* ctx, bool is_external_memory, size_t n_rows,
126-
const BatchParam& batch_param, float subsample, int sampling_method);
125+
GradientBasedSampler(Context const* ctx, size_t n_rows, const BatchParam& batch_param,
126+
float subsample, int sampling_method, bool is_external_memory);
127127

128128
/*! \brief Sample from a DMatrix based on the given gradient pairs. */
129129
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair, DMatrix* dmat);

src/tree/updater_gpu_hist.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,8 @@ struct GPUHistMakerDevice {
216216
column_sampler(column_sampler_seed),
217217
interaction_constraints(param, n_features),
218218
batch_param(std::move(_batch_param)) {
219-
sampler.reset(new GradientBasedSampler(ctx, is_external_memory, _n_rows, batch_param,
220-
param.subsample, param.sampling_method));
219+
sampler.reset(new GradientBasedSampler(ctx, _n_rows, batch_param, param.subsample,
220+
param.sampling_method, is_external_memory));
221221
if (!param.monotone_constraints.empty()) {
222222
// Copy assigning an empty vector causes an exception in MSVC debug builds
223223
monotone_constraints = param.monotone_constraints;

tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ void VerifySampling(size_t page_size,
3939
EXPECT_NE(page->n_rows, kRows);
4040
}
4141

42-
GradientBasedSampler sampler(&ctx, page, kRows, param, subsample, sampling_method);
42+
GradientBasedSampler sampler(&ctx, kRows, param, subsample, sampling_method,
43+
!fixed_size_sampling);
4344
auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
4445

4546
if (fixed_size_sampling) {
@@ -93,7 +94,7 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) {
9394
auto page = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
9495
EXPECT_NE(page->n_rows, kRows);
9596

96-
GradientBasedSampler sampler(&ctx, page, kRows, param, kSubsample, TrainParam::kUniform);
97+
GradientBasedSampler sampler(&ctx, kRows, param, kSubsample, TrainParam::kUniform, true);
9798
auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
9899
auto sampled_page = sample.page;
99100
EXPECT_EQ(sample.sample_rows, kRows);
@@ -141,7 +142,8 @@ TEST(GradientBasedSampler, GradientBasedSampling) {
141142
constexpr size_t kPageSize = 0;
142143
constexpr float kSubsample = 0.8;
143144
constexpr int kSamplingMethod = TrainParam::kGradientBased;
144-
VerifySampling(kPageSize, kSubsample, kSamplingMethod);
145+
constexpr bool kFixedSizeSampling = true;
146+
VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling);
145147
}
146148

147149
TEST(GradientBasedSampler, GradientBasedSamplingExternalMemory) {

0 commit comments

Comments
 (0)