Commit cb77e3a

Initial support for one hot categorical split.
1 parent: 20c95be

42 files changed: +1038 −249 lines. (Large commit: some files are hidden by default, so only part of the diff appears below.)

include/xgboost/feature_map.h

Lines changed: 3 additions & 1 deletion
@@ -82,7 +82,9 @@ class FeatureMap {
     if (!strcmp("q", tname)) return kQuantitive;
     if (!strcmp("int", tname)) return kInteger;
     if (!strcmp("float", tname)) return kFloat;
-    LOG(FATAL) << "unknown feature type, use i for indicator and q for quantity";
+    if (!strcmp("categorical", tname)) return kInteger;
+    LOG(FATAL) << "unknown feature type, use i for indicator, q for quantity "
+                  "and categorical for categorical split.";
     return kIndicator;
   }
   /*! \brief name of the feature */
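With this change a feature map file can label a column as categorical; as the diff shows, the type is currently parsed to kInteger internally. A hypothetical fmap file in XGBoost's existing `<index> <name> <type>` layout (the feature names here are made up for illustration):

    0 age q
    1 color categorical
    2 is_member i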

include/xgboost/version_config.h

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 #define XGBOOST_VERSION_CONFIG_H_

 #define XGBOOST_VER_MAJOR 1
-#define XGBOOST_VER_MINOR 2
+#define XGBOOST_VER_MINOR 3
 #define XGBOOST_VER_PATCH 0

 #endif  // XGBOOST_VERSION_CONFIG_H_

python-package/xgboost/core.py

Lines changed: 15 additions & 2 deletions
@@ -382,7 +382,8 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
                  silent=False,
                  feature_names=None,
                  feature_types=None,
-                 nthread=None):
+                 nthread=None,
+                 enable_categorical=False):
         """Parameters
         ----------
         data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/
@@ -417,6 +418,17 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
             Number of threads to use for loading data when parallelization is
             applicable. If -1, uses maximum threads available on the system.

+        enable_categorical: boolean, optional
+
+            .. versionadded:: 1.3.0
+
+            Experimental support of specializing for categorical features.  Do
+            not set to True unless you are interested in development.
+            Currently it's only available for `gpu_hist` tree method with 1 vs
+            rest (one hot) categorical split.  Also, JSON serialization format,
+            `enable_experimental_json_serialization`, `gpu_predictor` and
+            pandas input are required.
+
         """
         if isinstance(data, list):
             raise TypeError('Input data can not be a list.')
@@ -435,7 +447,8 @@ def __init__(self, data, label=None, weight=None, base_margin=None,
                 data, missing=self.missing,
                 threads=self.nthread,
                 feature_names=feature_names,
-                feature_types=feature_types)
+                feature_types=feature_types,
+                enable_categorical=enable_categorical)
             assert handle is not None
             self.handle = handle
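A minimal usage sketch for the new flag (not part of the commit; the data and parameter choices are illustrative). It needs a CUDA-enabled build, since only `gpu_hist` implements the one-hot split, and per the docstring above the JSON serialization format and `gpu_predictor` must also be enabled:

import pandas as pd
import xgboost as xgb

# Category values must themselves be float-convertible at this stage,
# so integer-valued categories are used here.
X = pd.DataFrame({
    'color': pd.Categorical([0, 1, 0, 2]),
    'size': [1.0, 2.0, 1.5, 3.0],
})
y = [0, 1, 0, 1]

# Without enable_categorical=True the categorical dtype is rejected
# by _transform_pandas_df (see data.py below).
dtrain = xgb.DMatrix(X, label=y, enable_categorical=True)
booster = xgb.train({'tree_method': 'gpu_hist'}, dtrain, num_boost_round=10)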

python-package/xgboost/data.py

Lines changed: 24 additions & 11 deletions
@@ -168,20 +168,24 @@ def _is_pandas_df(data):
 }


-def _transform_pandas_df(data, feature_names=None, feature_types=None,
+def _transform_pandas_df(data, enable_categorical,
+                         feature_names=None, feature_types=None,
                          meta=None, meta_type=None):
     from pandas import MultiIndex, Int64Index
-    from pandas.api.types import is_sparse
+    from pandas.api.types import is_sparse, is_categorical
+
     data_dtypes = data.dtypes
-    if not all(dtype.name in _pandas_dtype_mapper or is_sparse(dtype)
+    if not all(dtype.name in _pandas_dtype_mapper or is_sparse(dtype) or
+               (is_categorical(dtype) and enable_categorical)
                for dtype in data_dtypes):
         bad_fields = [
             str(data.columns[i]) for i, dtype in enumerate(data_dtypes)
             if dtype.name not in _pandas_dtype_mapper
         ]

-        msg = """DataFrame.dtypes for data must be int, float or bool.
-                Did not expect the data types in fields """
+        msg = """DataFrame.dtypes for data must be int, float, bool or categorical.  When
+                categorical type is supplied, DMatrix parameter
+                `enable_categorical` must be set to `True`."""
         raise ValueError(msg + ', '.join(bad_fields))

     if feature_names is None and meta is None:
@@ -200,6 +204,8 @@ def _transform_pandas_df(data, feature_names=None, feature_types=None,
             if is_sparse(dtype):
                 feature_types.append(_pandas_dtype_mapper[
                     dtype.subtype.name])
+            elif is_categorical(dtype) and enable_categorical:
+                feature_types.append('categorical')
             else:
                 feature_types.append(_pandas_dtype_mapper[dtype.name])

@@ -209,14 +215,19 @@ def _transform_pandas_df(data, feature_names=None, feature_types=None,
             meta=meta))

     dtype = meta_type if meta_type else 'float'
-    data = data.values.astype(dtype)
+    try:
+        data = data.values.astype(dtype)
+    except ValueError as e:
+        raise ValueError('Data must be convertable to float, even ' +
+                         'for categorical data.') from e

     return data, feature_names, feature_types


-def _from_pandas_df(data, missing, nthread, feature_names, feature_types):
+def _from_pandas_df(data, enable_categorical, missing, nthread,
+                    feature_names, feature_types):
     data, feature_names, feature_types = _transform_pandas_df(
-        data, feature_names, feature_types)
+        data, enable_categorical, feature_names, feature_types)
     return _from_numpy_array(data, missing, nthread, feature_names,
                              feature_types)

@@ -484,7 +495,8 @@ def _has_array_protocol(data):


 def dispatch_data_backend(data, missing, threads,
-                          feature_names, feature_types):
+                          feature_names, feature_types,
+                          enable_categorical=False):
     '''Dispatch data for DMatrix.'''
     if _is_scipy_csr(data):
         return _from_scipy_csr(data, missing, feature_names, feature_types)
@@ -500,7 +512,7 @@ def dispatch_data_backend(data, missing, threads,
     if _is_tuple(data):
         return _from_tuple(data, missing, feature_names, feature_types)
     if _is_pandas_df(data):
-        return _from_pandas_df(data, missing, threads,
+        return _from_pandas_df(data, enable_categorical, missing, threads,
                                feature_names, feature_types)
     if _is_pandas_series(data):
         return _from_pandas_series(data, missing, threads, feature_names,
@@ -624,7 +636,8 @@ def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None):
         _meta_from_numpy(data, name, dtype, handle)
         return
     if _is_pandas_df(data):
-        data, _, _ = _transform_pandas_df(data, meta=name, meta_type=dtype)
+        data, _, _ = _transform_pandas_df(data, False, meta=name,
+                                          meta_type=dtype)
         _meta_from_numpy(data, name, dtype, handle)
         return
     if _is_pandas_series(data):
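The try/except added around `data.values.astype(dtype)` reflects a constraint of this initial version: pandas expands a categorical column back into its category values, so those values must themselves be numeric. A small illustration in plain pandas (not commit code):

import pandas as pd

ok = pd.DataFrame({'c': pd.Categorical([1.0, 2.0, 1.0])})
ok.values.astype('float')       # fine: the categories are numeric

bad = pd.DataFrame({'c': pd.Categorical(['a', 'b', 'a'])})
try:
    bad.values.astype('float')  # ValueError; re-raised with the message above
except ValueError:
    # one workaround: feed the integer codes instead of the raw categories
    bad['c'] = bad['c'].cat.codes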

src/common/device_helpers.cuh

Lines changed: 16 additions & 0 deletions
@@ -80,6 +80,11 @@ struct AtomicDispatcher<sizeof(uint64_t)> {
   using Type = unsigned long long;  // NOLINT
   static_assert(sizeof(Type) == sizeof(uint64_t), "Unsigned long long should be of size 64 bits.");
 };
+
+template <>
+struct AtomicDispatcher<sizeof(uint8_t)> {
+  using Type = uint8_t;  // NOLINT
+};
 }  // namespace detail
 }  // namespace dh

@@ -522,6 +527,17 @@ void CopyDeviceSpanToVector(std::vector<T> *dst, xgboost::common::Span<const T>
                             cudaMemcpyDeviceToHost));
 }

+template <class HContainer, class DContainer>
+void CopyToD(HContainer const &h, DContainer *d) {
+  d->resize(h.size());
+  using HVT = std::remove_cv_t<typename HContainer::value_type>;
+  using DVT = std::remove_cv_t<typename DContainer::value_type>;
+  static_assert(std::is_same<HVT, DVT>::value,
+                "Host and device containers must have same value type.");
+  dh::safe_cuda(cudaMemcpyAsync(d->data().get(), h.data(), h.size() * sizeof(HVT),
+                                cudaMemcpyHostToDevice));
+}
+
 // Keep track of pinned memory allocation
 struct PinnedMemory {
   void *temp_storage{nullptr};
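A rough Python analogue of the new CopyToD helper, with CuPy standing in for the Thrust device container (an assumption made purely for illustration; the real helper is a C++ template): size the device buffer to match the host data, then issue a host-to-device copy.

import numpy as np
import cupy as cp

def copy_to_d(h: np.ndarray) -> cp.ndarray:
    d = cp.empty(h.shape, dtype=h.dtype)  # mirrors d->resize(h.size())
    d.set(h)  # mirrors cudaMemcpyAsync(..., cudaMemcpyHostToDevice)
    return d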

src/common/hist_util.cu

Lines changed: 59 additions & 14 deletions
@@ -24,6 +24,7 @@
 #include "hist_util.cuh"
 #include "math.h"  // NOLINT
 #include "quantile.h"
+#include "categorical.h"
 #include "xgboost/host_device_vector.h"

@@ -36,6 +37,7 @@ namespace detail {

 // Count the entries in each column and exclusive scan
 void ExtractCutsSparse(int device, common::Span<SketchContainer::OffsetT const> cuts_ptr,
+                       common::Span<FeatureType const> feature_types,
                        Span<Entry const> sorted_data,
                        Span<size_t const> column_sizes_scan,
                        Span<SketchEntry> out_cuts) {
@@ -48,10 +50,16 @@ void ExtractCutsSparse(int device, common::Span<SketchContainer::OffsetT const>
     size_t cut_idx = idx - cuts_ptr[column_idx];
     Span<Entry const> column_entries =
         sorted_data.subspan(column_sizes_scan[column_idx], column_size);
-    size_t rank = (column_entries.size() * cut_idx) /
-                  static_cast<float>(num_available_cuts);
-    out_cuts[idx] = WQSketch::Entry(rank, rank + 1, 1,
-                                    column_entries[rank].fvalue);
+    if (IsCat(feature_types, column_idx)) {
+      size_t rank = cut_idx;
+      out_cuts[idx] = WQSketch::Entry(rank, rank + 1, 1,
+                                      column_entries[rank].fvalue);
+    } else {
+      size_t rank = (column_entries.size() * cut_idx) /
+                    static_cast<float>(num_available_cuts);
+      out_cuts[idx] = WQSketch::Entry(rank, rank + 1, 1,
+                                      column_entries[rank].fvalue);
+    }
   });
 }

@@ -196,13 +204,13 @@ void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
 }
 }  // namespace detail

-void ProcessBatch(int device, const SparsePage &page, size_t begin, size_t end,
-                  SketchContainer *sketch_container, int num_cuts_per_feature,
-                  size_t num_columns) {
+void ProcessBatch(int device, DMatrix const *m, const SparsePage &page,
+                  size_t begin, size_t end, SketchContainer *sketch_container,
+                  int num_cuts_per_feature, size_t num_columns) {
   dh::XGBCachingDeviceAllocator<char> alloc;
   const auto& host_data = page.data.ConstHostVector();
   dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
-                                          host_data.begin() + end);
+                                         host_data.begin() + end);
   thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
                sorted_entries.end(), detail::EntryCompareOp());

@@ -219,13 +227,48 @@ void ProcessBatch(int device, const SparsePage &page, size_t begin, size_t end,
                                 0, sorted_entries.size(),
                                 &cuts_ptr, &column_sizes_scan);

+  // Removing duplicated entries in categorical features.
+  dh::caching_device_vector<size_t> new_column_scan(column_sizes_scan.size());
+  auto d_feature_types = m->Info().feature_types.ConstDeviceSpan();
+  auto n_uniques = dh::SegmentedUnique(
+      column_sizes_scan.data().get(),
+      column_sizes_scan.data().get() + column_sizes_scan.size(),
+      sorted_entries.begin(), sorted_entries.end(),
+      new_column_scan.data().get(), sorted_entries.begin(),
+      [=] __device__(Entry const &l, Entry const &r) {
+        if (l.index == r.index) {
+          if (IsCat(d_feature_types, l.index)) {
+            return l.fvalue == r.fvalue;
+          }
+        }
+        return false;
+      });
+
+  // Renew the column scan and cut scan based on categorical data.
+  dh::caching_device_vector<SketchContainer::OffsetT> new_cuts_size(num_columns + 1);
+  auto d_new_cuts_size = dh::ToSpan(new_cuts_size);
+  auto d_new_columns_ptr = dh::ToSpan(new_column_scan);
+  auto d_cuts_ptr = cuts_ptr.DeviceSpan();
+  CHECK_EQ(new_column_scan.size(), new_cuts_size.size());
+  dh::LaunchN(device, new_column_scan.size() - 1, [=] __device__(size_t idx) {
+    idx += 1;
+    if (IsCat(d_feature_types, idx - 1)) {
+      d_new_cuts_size[idx - 1] =
+          d_new_columns_ptr[idx] - d_new_columns_ptr[idx - 1];
+    } else {
+      d_new_cuts_size[idx - 1] = d_cuts_ptr[idx] - d_cuts_ptr[idx - 1];
+    }
+  });
+  thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(),
+                         new_cuts_size.cend(), d_cuts_ptr.data());
   auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
+  sorted_entries.resize(n_uniques);
   dh::caching_device_vector<SketchEntry> cuts(h_cuts_ptr.back());
-  auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
-
   CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
-  detail::ExtractCutsSparse(device, d_cuts_ptr, dh::ToSpan(sorted_entries),
-                            dh::ToSpan(column_sizes_scan), dh::ToSpan(cuts));
+
+  detail::ExtractCutsSparse(device, d_cuts_ptr, d_feature_types,
+                            dh::ToSpan(sorted_entries),
+                            dh::ToSpan(new_column_scan), dh::ToSpan(cuts));

   // add cuts into sketches
   sorted_entries.clear();
@@ -313,7 +356,9 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
                                                       device, num_cuts_per_feature, has_weights);

   HistogramCuts cuts;
-  SketchContainer sketch_container(max_bins, dmat->Info().num_col_,
+
+  dmat->Info().feature_types.SetDevice(device);
+  SketchContainer sketch_container(dmat->Info().feature_types, max_bins, dmat->Info().num_col_,
                                    dmat->Info().num_row_, device);

   dmat->Info().weights_.SetDevice(device);
@@ -333,7 +378,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
                            dmat->Info().num_col_,
                            is_ranking, dh::ToSpan(groups));
     } else {
-      ProcessBatch(device, batch, begin, end, &sketch_container, num_cuts_per_feature,
+      ProcessBatch(device, dmat, batch, begin, end, &sketch_container, num_cuts_per_feature,
                    dmat->Info().num_col_);
     }
   }
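A conceptual NumPy sketch of what ProcessBatch now computes per column (this is not the CUDA code): categorical columns are deduplicated so each distinct value becomes its own cut, while numerical columns keep the evenly spaced rank sampling from ExtractCutsSparse.

import numpy as np

def extract_cuts(column_values, is_categorical, num_available_cuts):
    col = np.sort(np.asarray(column_values, dtype=float))
    if is_categorical:
        # SegmentedUnique equivalent: one cut per distinct category.
        return np.unique(col)
    # Numerical path: rank = size * cut_idx / num_available_cuts.
    ranks = (len(col) * np.arange(num_available_cuts)) // num_available_cuts
    return col[ranks]

print(extract_cuts([2, 0, 2, 1, 0], True, 4))   # -> [0. 1. 2.]
print(extract_cuts(range(100), False, 4))       # -> [ 0. 25. 50. 75.]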

src/common/hist_util.cuh

Lines changed: 2 additions & 0 deletions
@@ -38,6 +38,7 @@ struct EntryCompareOp {
  * \param out_cuts Output cut values
  */
 void ExtractCutsSparse(int device, common::Span<SketchContainer::OffsetT const> cuts_ptr,
+                       common::Span<FeatureType const> feature_types,
                        Span<Entry const> sorted_data,
                        Span<size_t const> column_sizes_scan,
                        Span<SketchEntry> out_cuts);
@@ -189,6 +190,7 @@ void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns,
   dh::caching_device_vector<SketchEntry> cuts(h_cuts_ptr.back());
   // Extract the cuts from all columns concurrently
   detail::ExtractCutsSparse(device, d_cuts_ptr,
+                            {},
                             dh::ToSpan(sorted_entries),
                             dh::ToSpan(column_sizes_scan),
                             dh::ToSpan(cuts));

src/common/host_device_vector.cc

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@
 #include <cstdint>
 #include <memory>
 #include <utility>
+#include "xgboost/tree_model.h"
 #include "xgboost/host_device_vector.h"

 namespace xgboost {
@@ -176,6 +177,7 @@ template class HostDeviceVector<FeatureType>;
 template class HostDeviceVector<Entry>;
 template class HostDeviceVector<uint64_t>;  // bst_row_t
 template class HostDeviceVector<uint32_t>;  // bst_feature_t
+template class HostDeviceVector<RegTree::Segment>;

 #if defined(__APPLE__)
 /*

src/common/host_device_vector.cu

Lines changed: 1 addition & 0 deletions
@@ -404,6 +404,7 @@ template class HostDeviceVector<Entry>;
 template class HostDeviceVector<uint64_t>;  // bst_row_t
 template class HostDeviceVector<uint32_t>;  // bst_feature_t
 template class HostDeviceVector<RegTree::Node>;
+template class HostDeviceVector<RegTree::Segment>;

 #if defined(__APPLE__)
 /*

src/common/observer.h

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ class TrainingObserver {

     for (size_t i = 0; i < h_vec.size(); ++i) {
       OBSERVER_PRINT << h_vec[i] << ", ";
-      if (i % 8 == 0) {
+      if (i % 8 == 0 && i != 0) {
         OBSERVER_PRINT << OBSERVER_NEWLINE;
       }
       if ((i + 1) == n) {
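The added `i != 0` guard matters because `0 % 8 == 0`: the old condition emitted a newline right after the very first element, leaving a one-element first row. A quick sketch of both behaviors:

def render(n, skip_zero):
    out = ''
    for i in range(n):
        out += f'{i}, '
        if i % 8 == 0 and (i != 0 or not skip_zero):
            out += '\n'
    return out

print(render(12, skip_zero=False))  # old behavior: lone '0,' on the first line
print(render(12, skip_zero=True))   # fixed: first newline comes after i == 8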
