Skip to content

Commit f383f76

Browse files
committed
remove page in uniform sampling.
1 parent 1b0dab2 commit f383f76

File tree

3 files changed

+8
-7
lines changed

3 files changed

+8
-7
lines changed

src/tree/gpu_hist/gradient_based_sampler.cu

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,8 @@ GradientBasedSample ExternalMemoryNoSampling::Sample(Context const* ctx,
175175
return {dmat->Info().num_row_, page_.get(), gpair};
176176
}
177177

178-
UniformSampling::UniformSampling(EllpackPageImpl const* page, float subsample)
179-
: page_(page), subsample_(subsample) {}
178+
UniformSampling::UniformSampling(BatchParam batch_param, float subsample)
179+
: batch_param_{std::move(batch_param)}, subsample_(subsample) {}
180180

181181
GradientBasedSample UniformSampling::Sample(Context const* ctx, common::Span<GradientPair> gpair,
182182
DMatrix* dmat) {
@@ -185,7 +185,8 @@ GradientBasedSample UniformSampling::Sample(Context const* ctx, common::Span<Gra
185185
thrust::replace_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
186186
thrust::counting_iterator<std::size_t>(0),
187187
BernoulliTrial(common::GlobalRandom()(), subsample_), GradientPair());
188-
return {dmat->Info().num_row_, page_, gpair};
188+
auto page = (*dmat->GetBatches<EllpackPage>(ctx, batch_param_).begin()).Impl();
189+
return {dmat->Info().num_row_, page, gpair};
189190
}
190191

191192
ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(size_t n_rows,
@@ -331,7 +332,7 @@ GradientBasedSampler::GradientBasedSampler(Context const* ctx, EllpackPageImpl c
331332
if (is_external_memory) {
332333
strategy_.reset(new ExternalMemoryUniformSampling(n_rows, batch_param, subsample));
333334
} else {
334-
strategy_.reset(new UniformSampling(page, subsample));
335+
strategy_.reset(new UniformSampling(batch_param, subsample));
335336
}
336337
break;
337338
case TrainParam::kGradientBased:

src/tree/gpu_hist/gradient_based_sampler.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,12 @@ class ExternalMemoryNoSampling : public SamplingStrategy {
5757
/*! \brief Uniform sampling in in-memory mode. */
5858
class UniformSampling : public SamplingStrategy {
5959
public:
60-
UniformSampling(EllpackPageImpl const* page, float subsample);
60+
UniformSampling(BatchParam batch_param, float subsample);
6161
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
6262
DMatrix* dmat) override;
6363

6464
private:
65-
EllpackPageImpl const* page_;
65+
BatchParam batch_param_;
6666
float subsample_;
6767
};
6868

src/tree/updater_gpu_hist.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ struct GPUHistMakerDevice {
176176
Context const* ctx_;
177177

178178
public:
179-
EllpackPageImpl const* page;
179+
EllpackPageImpl const* page{nullptr};
180180
common::Span<FeatureType const> feature_types;
181181
BatchParam batch_param;
182182

0 commit comments

Comments
 (0)