Skip to content

Commit 519cee1

Browse files
authored
Avoid resetting seed for every configuration. (#6349)
1 parent f3a4253 commit 519cee1

File tree

4 files changed

+41
-4
lines changed

4 files changed

+41
-4
lines changed

doc/parameter.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,10 @@ Specify the learning task and the corresponding learning objective. The objectiv
412412

413413
- Random number seed. This parameter is ignored in R package, use `set.seed()` instead.
414414

415+
* ``seed_per_iteration`` [default=false]
416+
417+
- Seed PRNG deterministically via iteration number. This option will be switched on automatically in distributed mode.
418+
415419
***********************
416420
Command Line Parameters
417421
***********************

include/xgboost/generic_parameters.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@ namespace xgboost {
1414
struct GenericParameter : public XGBoostParameter<GenericParameter> {
1515
// Constant representing the device ID of CPU.
1616
static int32_t constexpr kCpuId = -1;
17+
static int64_t constexpr kDefaultSeed = 0;
1718

1819
public:
1920
// stored random seed
20-
int64_t seed;
21+
int64_t seed { kDefaultSeed };
2122
// whether to seed the PRNG each iteration
2223
bool seed_per_iteration;
2324
// number of threads to use if OpenMP is enabled
@@ -46,7 +47,7 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
4647

4748
// declare parameters
4849
DMLC_DECLARE_PARAMETER(GenericParameter) {
49-
DMLC_DECLARE_FIELD(seed).set_default(0).describe(
50+
DMLC_DECLARE_FIELD(seed).set_default(kDefaultSeed).describe(
5051
"Random number seed during training.");
5152
DMLC_DECLARE_ALIAS(seed, random_state);
5253
DMLC_DECLARE_FIELD(seed_per_iteration)

src/learner.cc

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ DMLC_REGISTER_PARAMETER(LearnerTrainParam);
202202
DMLC_REGISTER_PARAMETER(GenericParameter);
203203

204204
int constexpr GenericParameter::kCpuId;
205+
int64_t constexpr GenericParameter::kDefaultSeed;
205206

206207
void GenericParameter::ConfigureGpuId(bool require_gpu) {
207208
#if defined(XGBOOST_USE_CUDA)
@@ -239,6 +240,9 @@ using ThreadLocalPredictionCache =
239240
dmlc::ThreadLocalStore<std::map<Learner const *, PredictionContainer>>;
240241

241242
class LearnerConfiguration : public Learner {
243+
private:
244+
std::mutex config_lock_;
245+
242246
protected:
243247
static std::string const kEvalMetric; // NOLINT
244248

@@ -252,7 +256,6 @@ class LearnerConfiguration : public Learner {
252256
LearnerModelParam learner_model_param_;
253257
LearnerTrainParam tparam_;
254258
std::vector<std::string> metric_names_;
255-
std::mutex config_lock_;
256259

257260
public:
258261
explicit LearnerConfiguration(std::vector<std::shared_ptr<DMatrix> > cache)
@@ -283,7 +286,11 @@ class LearnerConfiguration : public Learner {
283286

284287
tparam_.UpdateAllowUnknown(args);
285288
auto mparam_backup = mparam_;
289+
286290
mparam_.UpdateAllowUnknown(args);
291+
292+
auto initialized = generic_parameters_.GetInitialised();
293+
auto old_seed = generic_parameters_.seed;
287294
generic_parameters_.UpdateAllowUnknown(args);
288295
generic_parameters_.CheckDeprecated();
289296

@@ -297,7 +304,9 @@ class LearnerConfiguration : public Learner {
297304
}
298305

299306
// set seed only before the model is initialized, or when the seed parameter has changed
300-
common::GlobalRandom().seed(generic_parameters_.seed);
307+
if (!initialized || generic_parameters_.seed != old_seed) {
308+
common::GlobalRandom().seed(generic_parameters_.seed);
309+
}
301310

302311
// must precede configure gbm since num_features is required for gbm
303312
this->ConfigureNumFeatures();

tests/cpp/test_learner.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <xgboost/version_config.h>
1212
#include "xgboost/json.h"
1313
#include "../../src/common/io.h"
14+
#include "../../src/common/random.h"
1415

1516
namespace xgboost {
1617

@@ -333,4 +334,26 @@ TEST(Learner, Seed) {
333334
ASSERT_EQ(std::to_string(seed),
334335
get<String>(config["learner"]["generic_param"]["seed"]));
335336
}
337+
338+
TEST(Learner, ConstantSeed) {
339+
auto m = RandomDataGenerator{10, 10, 0}.GenerateDMatrix(true);
340+
std::unique_ptr<Learner> learner{Learner::Create({m})};
341+
learner->Configure(); // seed the global random
342+
343+
std::uniform_real_distribution<float> dist;
344+
auto& rng = common::GlobalRandom();
345+
float v_0 = dist(rng);
346+
347+
learner->SetParam("", "");
348+
learner->Configure(); // check configure doesn't change the seed.
349+
float v_1 = dist(rng);
350+
CHECK_NE(v_0, v_1);
351+
352+
{
353+
rng.seed(GenericParameter::kDefaultSeed);
354+
std::uniform_real_distribution<float> dist;
355+
float v_2 = dist(rng);
356+
CHECK_EQ(v_0, v_2);
357+
}
358+
}
336359
} // namespace xgboost

0 commit comments

Comments
 (0)