Skip to content

Commit fc06538

Browse files
committed
Enforce tree order in JSON. (dmlc#5974)
* Make JSON model IO more future proof by using tree id in model loading.
1 parent b13fcfe commit fc06538

File tree

2 files changed

+23
-12
lines changed

2 files changed

+23
-12
lines changed

src/gbm/gbtree_model.cc

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
/*!
2-
* Copyright 2019 by Contributors
2+
* Copyright 2019-2020 by Contributors
33
*/
4+
#include <utility>
5+
46
#include "xgboost/json.h"
57
#include "xgboost/logging.h"
68
#include "gbtree_model.h"
@@ -41,15 +43,14 @@ void GBTreeModel::SaveModel(Json* p_out) const {
4143
auto& out = *p_out;
4244
CHECK_EQ(param.num_trees, static_cast<int>(trees.size()));
4345
out["gbtree_model_param"] = ToJson(param);
44-
std::vector<Json> trees_json;
45-
size_t t = 0;
46-
for (auto const& tree : trees) {
46+
std::vector<Json> trees_json(trees.size());
47+
48+
for (size_t t = 0; t < trees.size(); ++t) {
49+
auto const& tree = trees[t];
4750
Json tree_json{Object()};
4851
tree->SaveModel(&tree_json);
49-
// The field is not used in XGBoost, but might be useful for external project.
50-
tree_json["id"] = Integer(t);
51-
trees_json.emplace_back(tree_json);
52-
t++;
52+
tree_json["id"] = Integer(static_cast<Integer::Int>(t));
53+
trees_json[t] = std::move(tree_json);
5354
}
5455

5556
std::vector<Json> tree_info_json(tree_info.size());
@@ -70,9 +71,10 @@ void GBTreeModel::LoadModel(Json const& in) {
7071
auto const& trees_json = get<Array const>(in["trees"]);
7172
trees.resize(trees_json.size());
7273

73-
for (size_t t = 0; t < trees.size(); ++t) {
74-
trees[t].reset( new RegTree() );
75-
trees[t]->LoadModel(trees_json[t]);
74+
for (size_t t = 0; t < trees_json.size(); ++t) { // NOLINT
75+
auto tree_id = get<Integer>(trees_json[t]["id"]);
76+
trees.at(tree_id).reset(new RegTree());
77+
trees.at(tree_id)->LoadModel(trees_json[t]);
7678
}
7779

7880
tree_info.resize(param.num_trees);

tests/cpp/test_learner.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,16 @@ TEST(Learner, JsonModelIO) {
148148
Json out { Object() };
149149
learner->SaveModel(&out);
150150

151-
learner->LoadModel(out);
151+
dmlc::TemporaryDirectory tmpdir;
152+
153+
std::ofstream fout (tmpdir.path + "/model.json");
154+
fout << out;
155+
fout.close();
156+
157+
auto loaded_str = common::LoadSequentialFile(tmpdir.path + "/model.json");
158+
Json loaded = Json::Load(StringView{loaded_str.c_str(), loaded_str.size()});
159+
160+
learner->LoadModel(loaded);
152161
learner->Configure();
153162

154163
Json new_in { Object() };

0 commit comments

Comments
 (0)