Skip to content

Commit 20c95be

Browse files
trivialfishcho3
andauthored
Expand categorical node. (#6028)
Co-authored-by: Philip Hyunsu Cho <[email protected]>
1 parent 9a4e8b1 commit 20c95be

File tree

12 files changed

+341
-104
lines changed

12 files changed

+341
-104
lines changed

R-package/tests/testthat/test_basic.R

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,11 +245,12 @@ test_that("training continuation works", {
245245
expect_equal(bst$raw, bst2$raw)
246246
expect_equal(dim(bst2$evaluation_log), c(2, 2))
247247
# test continuing from a model in file
248-
xgb.save(bst1, "xgboost.model")
249-
bst2 <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0, xgb_model = "xgboost.model")
248+
xgb.save(bst1, "xgboost.json")
249+
bst2 <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0, xgb_model = "xgboost.json")
250250
if (!windows_flag && !solaris_flag)
251251
expect_equal(bst$raw, bst2$raw)
252252
expect_equal(dim(bst2$evaluation_log), c(2, 2))
253+
file.remove("xgboost.json")
253254
})
254255

255256
test_that("model serialization works", {

R-package/tests/testthat/test_callbacks.R

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -173,16 +173,16 @@ test_that("cb.reset.parameters works as expected", {
173173
})
174174

175175
test_that("cb.save.model works as expected", {
176-
files <- c('xgboost_01.model', 'xgboost_02.model', 'xgboost.model')
176+
files <- c('xgboost_01.json', 'xgboost_02.json', 'xgboost.json')
177177
for (f in files) if (file.exists(f)) file.remove(f)
178178

179179
bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, eta = 1, verbose = 0,
180-
save_period = 1, save_name = "xgboost_%02d.model")
181-
expect_true(file.exists('xgboost_01.model'))
182-
expect_true(file.exists('xgboost_02.model'))
183-
b1 <- xgb.load('xgboost_01.model')
180+
save_period = 1, save_name = "xgboost_%02d.json")
181+
expect_true(file.exists('xgboost_01.json'))
182+
expect_true(file.exists('xgboost_02.json'))
183+
b1 <- xgb.load('xgboost_01.json')
184184
expect_equal(xgb.ntree(b1), 1)
185-
b2 <- xgb.load('xgboost_02.model')
185+
b2 <- xgb.load('xgboost_02.json')
186186
expect_equal(xgb.ntree(b2), 2)
187187

188188
xgb.config(b2) <- xgb.config(bst)
@@ -191,9 +191,9 @@ test_that("cb.save.model works as expected", {
191191

192192
# save_period = 0 saves the last iteration's model
193193
bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, eta = 1, verbose = 0,
194-
save_period = 0)
195-
expect_true(file.exists('xgboost.model'))
196-
b2 <- xgb.load('xgboost.model')
194+
save_period = 0, save_name = 'xgboost.json')
195+
expect_true(file.exists('xgboost.json'))
196+
b2 <- xgb.load('xgboost.json')
197197
xgb.config(b2) <- xgb.config(bst)
198198
expect_equal(bst$raw, b2$raw)
199199

include/xgboost/base.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ using bst_int = int32_t; // NOLINT
109109
using bst_ulong = uint64_t; // NOLINT
110110
/*! \brief float type, used for storing statistics */
111111
using bst_float = float; // NOLINT
112-
112+
/*! \brief Categorical value type. */
113+
using bst_cat_t = int32_t; // NOLINT
113114
/*! \brief Type for data column (feature) index. */
114115
using bst_feature_t = uint32_t; // NOLINT
115116
/*! \brief Type for data row index.

include/xgboost/data.h

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ enum class DataType : uint8_t {
3535
};
3636

3737
enum class FeatureType : uint8_t {
38-
kNumerical
38+
kNumerical,
39+
kCategorical
3940
};
4041

4142
/*!
@@ -309,12 +310,6 @@ class SparsePage {
309310
}
310311
}
311312

312-
/*!
313-
* \brief Push row block into the page.
314-
* \param batch the row batch.
315-
*/
316-
void Push(const dmlc::RowBlock<uint32_t>& batch);
317-
318313
/**
319314
* \brief Pushes external data batch onto this page
320315
*

include/xgboost/span.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,18 @@ namespace common {
101101
} while (0);
102102
#endif // __CUDA_ARCH__
103103

104+
#if defined(__CUDA_ARCH__)
105+
#define SPAN_LT(lhs, rhs) \
106+
if (!((lhs) < (rhs))) { \
107+
printf("%lu < %lu failed\n", static_cast<size_t>(lhs), \
108+
static_cast<size_t>(rhs)); \
109+
asm("trap;"); \
110+
}
111+
#else
112+
#define SPAN_LT(lhs, rhs) \
113+
SPAN_CHECK((lhs) < (rhs))
114+
#endif // defined(__CUDA_ARCH__)
115+
104116
namespace detail {
105117
/*!
106118
* By default, XGBoost uses uint32_t for indexing data. int64_t covers all
@@ -515,7 +527,7 @@ class Span {
515527
}
516528

517529
XGBOOST_DEVICE reference operator[](index_type _idx) const {
518-
SPAN_CHECK(_idx < size());
530+
SPAN_LT(_idx, size());
519531
return data()[_idx];
520532
}
521533

@@ -575,7 +587,6 @@ class Span {
575587
detail::ExtentValue<Extent, Offset, Count>::value> {
576588
SPAN_CHECK((Count == dynamic_extent) ?
577589
(Offset <= size()) : (Offset + Count <= size()));
578-
579590
return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count};
580591
}
581592

include/xgboost/tree_model.h

Lines changed: 57 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,8 @@ class RegTree : public Model {
318318
param.num_deleted = 0;
319319
nodes_.resize(param.num_nodes);
320320
stats_.resize(param.num_nodes);
321+
split_types_.resize(param.num_nodes, FeatureType::kNumerical);
322+
split_categories_segments_.resize(param.num_nodes);
321323
for (int i = 0; i < param.num_nodes; i ++) {
322324
nodes_[i].SetLeaf(0.0f);
323325
nodes_[i].SetParent(kInvalidNodeId);
@@ -412,30 +414,33 @@ class RegTree : public Model {
412414
* \param leaf_right_child The right child index of leaf, by default kInvalidNodeId,
413415
* some updaters use the right child index of leaf as a marker
414416
*/
415-
void ExpandNode(int nid, unsigned split_index, bst_float split_value,
417+
void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value,
416418
bool default_left, bst_float base_weight,
417419
bst_float left_leaf_weight, bst_float right_leaf_weight,
418420
bst_float loss_change, float sum_hess, float left_sum,
419421
float right_sum,
420-
bst_node_t leaf_right_child = kInvalidNodeId) {
421-
int pleft = this->AllocNode();
422-
int pright = this->AllocNode();
423-
auto &node = nodes_[nid];
424-
CHECK(node.IsLeaf());
425-
node.SetLeftChild(pleft);
426-
node.SetRightChild(pright);
427-
nodes_[node.LeftChild()].SetParent(nid, true);
428-
nodes_[node.RightChild()].SetParent(nid, false);
429-
node.SetSplit(split_index, split_value,
430-
default_left);
431-
432-
nodes_[pleft].SetLeaf(left_leaf_weight, leaf_right_child);
433-
nodes_[pright].SetLeaf(right_leaf_weight, leaf_right_child);
434-
435-
this->Stat(nid) = {loss_change, sum_hess, base_weight};
436-
this->Stat(pleft) = {0.0f, left_sum, left_leaf_weight};
437-
this->Stat(pright) = {0.0f, right_sum, right_leaf_weight};
438-
}
422+
bst_node_t leaf_right_child = kInvalidNodeId);
423+
424+
/**
425+
* \brief Expands a leaf node with categories
426+
*
427+
* \param nid The node index to expand.
428+
* \param split_index Feature index of the split.
429+
* \param split_cat The bitset containing categories
430+
* \param default_left True to default left.
431+
* \param base_weight The base weight, before learning rate.
432+
* \param left_leaf_weight The left leaf weight for prediction, modified by learning rate.
433+
* \param right_leaf_weight The right leaf weight for prediction, modified by learning rate.
434+
* \param loss_change The loss change.
435+
* \param sum_hess The sum hess.
436+
* \param left_sum The sum hess of left leaf.
437+
* \param right_sum The sum hess of right leaf.
438+
*/
439+
void ExpandCategorical(bst_node_t nid, unsigned split_index,
440+
common::Span<uint32_t> split_cat, bool default_left,
441+
bst_float base_weight, bst_float left_leaf_weight,
442+
bst_float right_leaf_weight, bst_float loss_change,
443+
float sum_hess, float left_sum, float right_sum);
439444

440445
/*!
441446
* \brief get current depth
@@ -588,6 +593,28 @@ class RegTree : public Model {
588593
* \brief calculate the mean value for each node, required for feature contributions
589594
*/
590595
void FillNodeMeanValues();
596+
/*!
597+
* \brief Get split type for a node.
598+
* \param nidx Index of node.
599+
* \return The type of this split. For leaf node it's always kNumerical.
600+
*/
601+
FeatureType NodeSplitType(bst_node_t nidx) const {
602+
return split_types_.at(nidx);
603+
}
604+
/*!
605+
* \brief Get split types for all nodes.
606+
*/
607+
std::vector<FeatureType> const &GetSplitTypes() const { return split_types_; }
608+
common::Span<uint32_t const> GetSplitCategories() const { return split_categories_; }
609+
auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }
610+
611+
// The fields of split_categories_segments_[i] are set such that
612+
// the range split_categories_[beg:(beg+size)] stores the bitset for
613+
// the matching categories for the i-th node.
614+
struct Segment {
615+
size_t beg {0};
616+
size_t size {0};
617+
};
591618

592619
private:
593620
// vector of nodes
@@ -597,9 +624,16 @@ class RegTree : public Model {
597624
// stats of nodes
598625
std::vector<RTreeNodeStat> stats_;
599626
std::vector<bst_float> node_mean_values_;
627+
std::vector<FeatureType> split_types_;
628+
629+
// Categories for each internal node.
630+
std::vector<uint32_t> split_categories_;
631+
// Ptr to split categories of each node.
632+
std::vector<Segment> split_categories_segments_;
633+
600634
// allocate a new node,
601635
// !!!!!! NOTE: may cause BUG here, nodes.resize
602-
int AllocNode() {
636+
bst_node_t AllocNode() {
603637
if (param.num_deleted != 0) {
604638
int nid = deleted_nodes_.back();
605639
deleted_nodes_.pop_back();
@@ -612,6 +646,8 @@ class RegTree : public Model {
612646
<< "number of nodes in the tree exceed 2^31";
613647
nodes_.resize(param.num_nodes);
614648
stats_.resize(param.num_nodes);
649+
split_types_.resize(param.num_nodes, FeatureType::kNumerical);
650+
split_categories_segments_.resize(param.num_nodes);
615651
return nd;
616652
}
617653
// delete a tree node, keep the parent field to allow trace back

src/common/bitfield.h

Lines changed: 23 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#if defined(__CUDACC__)
1717
#include <thrust/copy.h>
1818
#include <thrust/device_ptr.h>
19+
#include "device_helpers.cuh"
1920
#endif // defined(__CUDACC__)
2021

2122
#include "xgboost/span.h"
@@ -54,23 +55,24 @@ __forceinline__ __device__ BitFieldAtomicType AtomicAnd(BitFieldAtomicType* addr
5455
*
5556
* \tparam Direction Whether the bits start from left or from right.
5657
*/
57-
template <typename VT, typename Direction>
58+
template <typename VT, typename Direction, bool IsConst = false>
5859
struct BitFieldContainer {
59-
using value_type = VT; // NOLINT
60+
using value_type = std::conditional_t<IsConst, VT const, VT>; // NOLINT
6061
using pointer = value_type*; // NOLINT
6162

6263
static value_type constexpr kValueSize = sizeof(value_type) * 8;
6364
static value_type constexpr kOne = 1; // force correct type.
6465

6566
struct Pos {
66-
value_type int_pos {0};
67-
value_type bit_pos {0};
67+
std::remove_const_t<value_type> int_pos {0};
68+
std::remove_const_t<value_type> bit_pos {0};
6869
};
6970

7071
private:
7172
common::Span<value_type> bits_;
7273
static_assert(!std::is_signed<VT>::value, "Must use unsiged type as underlying storage.");
7374

75+
public:
7476
XGBOOST_DEVICE static Pos ToBitPos(value_type pos) {
7577
Pos pos_v;
7678
if (pos == 0) {
@@ -92,7 +94,7 @@ struct BitFieldContainer {
9294
/*\brief Compute the size of needed memory allocation. The returned value is in terms
9395
* of number of elements with `BitFieldContainer::value_type'.
9496
*/
95-
static size_t ComputeStorageSize(size_t size) {
97+
XGBOOST_DEVICE static size_t ComputeStorageSize(size_t size) {
9698
return common::DivRoundUp(size, kValueSize);
9799
}
98100
#if defined(__CUDA_ARCH__)
@@ -134,19 +136,19 @@ struct BitFieldContainer {
134136
#endif // defined(__CUDA_ARCH__)
135137

136138
#if defined(__CUDA_ARCH__)
137-
__device__ void Set(value_type pos) {
139+
__device__ auto Set(value_type pos) {
138140
Pos pos_v = Direction::Shift(ToBitPos(pos));
139141
value_type& value = bits_[pos_v.int_pos];
140142
value_type set_bit = kOne << pos_v.bit_pos;
141-
static_assert(sizeof(BitFieldAtomicType) == sizeof(value_type), "");
142-
AtomicOr(reinterpret_cast<BitFieldAtomicType*>(&value), set_bit);
143+
using Type = typename dh::detail::AtomicDispatcher<sizeof(value_type)>::Type;
144+
atomicOr(reinterpret_cast<Type *>(&value), set_bit);
143145
}
144146
__device__ void Clear(value_type pos) {
145147
Pos pos_v = Direction::Shift(ToBitPos(pos));
146148
value_type& value = bits_[pos_v.int_pos];
147149
value_type clear_bit = ~(kOne << pos_v.bit_pos);
148-
static_assert(sizeof(BitFieldAtomicType) == sizeof(value_type), "");
149-
AtomicAnd(reinterpret_cast<BitFieldAtomicType*>(&value), clear_bit);
150+
using Type = typename dh::detail::AtomicDispatcher<sizeof(value_type)>::Type;
151+
atomicAnd(reinterpret_cast<Type *>(&value), clear_bit);
150152
}
151153
#else
152154
void Set(value_type pos) {
@@ -165,6 +167,7 @@ struct BitFieldContainer {
165167

166168
XGBOOST_DEVICE bool Check(Pos pos_v) const {
167169
pos_v = Direction::Shift(pos_v);
170+
SPAN_LT(pos_v.int_pos, bits_.size());
168171
value_type const value = bits_[pos_v.int_pos];
169172
value_type const test_bit = kOne << pos_v.bit_pos;
170173
value_type result = test_bit & value;
@@ -179,20 +182,21 @@ struct BitFieldContainer {
179182

180183
XGBOOST_DEVICE pointer Data() const { return bits_.data(); }
181184

182-
friend std::ostream& operator<<(std::ostream& os, BitFieldContainer<VT, Direction> field) {
185+
inline friend std::ostream &
186+
operator<<(std::ostream &os, BitFieldContainer<VT, Direction, IsConst> field) {
183187
os << "Bits " << "storage size: " << field.bits_.size() << "\n";
184188
for (typename common::Span<value_type>::index_type i = 0; i < field.bits_.size(); ++i) {
185-
std::bitset<BitFieldContainer<VT, Direction>::kValueSize> bset(field.bits_[i]);
189+
std::bitset<BitFieldContainer<VT, Direction, IsConst>::kValueSize> bset(field.bits_[i]);
186190
os << bset << "\n";
187191
}
188192
return os;
189193
}
190194
};
191195

192196
// Bits start from left most bits (most significant bit).
193-
template <typename VT>
194-
struct LBitsPolicy : public BitFieldContainer<VT, LBitsPolicy<VT>> {
195-
using Container = BitFieldContainer<VT, LBitsPolicy<VT>>;
197+
template <typename VT, bool IsConst = false>
198+
struct LBitsPolicy : public BitFieldContainer<VT, LBitsPolicy<VT, IsConst>, IsConst> {
199+
using Container = BitFieldContainer<VT, LBitsPolicy<VT, IsConst>, IsConst>;
196200
using Pos = typename Container::Pos;
197201
using value_type = typename Container::value_type; // NOLINT
198202

@@ -215,38 +219,13 @@ struct RBitsPolicy : public BitFieldContainer<VT, RBitsPolicy<VT>> {
215219
}
216220
};
217221

218-
// Format: <Direction>BitField<size of underlying type in bits>, underlying type must be unsigned.
222+
// Format: <Const><Direction>BitField<size of underlying type in bits>, underlying type
223+
// must be unsigned.
219224
using LBitField64 = BitFieldContainer<uint64_t, LBitsPolicy<uint64_t>>;
220225
using RBitField8 = BitFieldContainer<uint8_t, RBitsPolicy<unsigned char>>;
221226

222-
#if defined(__CUDACC__)
223-
224-
template <typename V, typename D>
225-
inline void PrintDeviceBits(std::string name, BitFieldContainer<V, D> field) {
226-
std::cout << "Bits: " << name << std::endl;
227-
std::vector<typename BitFieldContainer<V, D>::value_type> h_field_bits(field.bits_.size());
228-
thrust::copy(thrust::device_ptr<typename BitFieldContainer<V, D>::value_type>(field.bits_.data()),
229-
thrust::device_ptr<typename BitFieldContainer<V, D>::value_type>(
230-
field.bits_.data() + field.bits_.size()),
231-
h_field_bits.data());
232-
BitFieldContainer<V, D> h_field;
233-
h_field.bits_ = {h_field_bits.data(), h_field_bits.data() + h_field_bits.size()};
234-
std::cout << h_field;
235-
}
236-
237-
inline void PrintDeviceStorage(std::string name, common::Span<int32_t> list) {
238-
std::cout << name << std::endl;
239-
std::vector<int32_t> h_list(list.size());
240-
thrust::copy(thrust::device_ptr<int32_t>(list.data()),
241-
thrust::device_ptr<int32_t>(list.data() + list.size()),
242-
h_list.data());
243-
for (auto v : h_list) {
244-
std::cout << v << ", ";
245-
}
246-
std::cout << std::endl;
247-
}
248-
249-
#endif // defined(__CUDACC__)
227+
using LBitField32 = BitFieldContainer<uint32_t, LBitsPolicy<uint32_t>>;
228+
using CLBitField32 = BitFieldContainer<uint32_t, LBitsPolicy<uint32_t, true>, true>;
250229
} // namespace xgboost
251230

252231
#endif // XGBOOST_COMMON_BITFIELD_H_

0 commit comments

Comments
 (0)