@@ -175,8 +175,8 @@ GradientBasedSample ExternalMemoryNoSampling::Sample(Context const* ctx,
175
175
return {dmat->Info ().num_row_ , page_.get (), gpair};
176
176
}
177
177
178
- UniformSampling::UniformSampling (EllpackPageImpl const * page , float subsample)
179
- : page_(page) , subsample_(subsample) {}
178
+ UniformSampling::UniformSampling (BatchParam batch_param , float subsample)
179
+ : batch_param_{ std::move (batch_param)} , subsample_(subsample) {}
180
180
181
181
GradientBasedSample UniformSampling::Sample (Context const * ctx, common::Span<GradientPair> gpair,
182
182
DMatrix* dmat) {
@@ -185,7 +185,8 @@ GradientBasedSample UniformSampling::Sample(Context const* ctx, common::Span<Gra
185
185
thrust::replace_if (cuctx->CTP (), dh::tbegin (gpair), dh::tend (gpair),
186
186
thrust::counting_iterator<std::size_t >(0 ),
187
187
BernoulliTrial (common::GlobalRandom ()(), subsample_), GradientPair ());
188
- return {dmat->Info ().num_row_ , page_, gpair};
188
+ auto page = (*dmat->GetBatches <EllpackPage>(ctx, batch_param_).begin ()).Impl ();
189
+ return {dmat->Info ().num_row_ , page, gpair};
189
190
}
190
191
191
192
ExternalMemoryUniformSampling::ExternalMemoryUniformSampling (size_t n_rows,
@@ -331,7 +332,7 @@ GradientBasedSampler::GradientBasedSampler(Context const* ctx, EllpackPageImpl c
331
332
if (is_external_memory) {
332
333
strategy_.reset (new ExternalMemoryUniformSampling (n_rows, batch_param, subsample));
333
334
} else {
334
- strategy_.reset (new UniformSampling (page , subsample));
335
+ strategy_.reset (new UniformSampling (batch_param , subsample));
335
336
}
336
337
break ;
337
338
case TrainParam::kGradientBased :
0 commit comments