Commit bc267dd

Use ptr from mmap for GHistIndexMatrix and ColumnMatrix. (#9315)
* Use ptr from mmap for `GHistIndexMatrix` and `ColumnMatrix`.

- Define a resource for holding various types of memory pointers.
- Define ref vector for holding resources.
- Swap the underlying resources for GHist and ColumnM.
- Add documentation for current status.
- s390x support is removed. It should work if you can compile XGBoost; all the old workaround code did was get GCC to compile.
1 parent 96c3071 commit bc267dd

29 files changed, +1448 -509 lines changed

doc/c.rst

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,8 @@ DMatrix
3333
.. doxygengroup:: DMatrix
3434
:project: xgboost
3535

36+
.. _c_streaming:
37+
3638
Streaming
3739
---------
3840

doc/tutorials/dask.rst

Lines changed: 3 additions & 0 deletions
@@ -54,6 +54,9 @@ on a dask cluster:
5454
y = da.random.random(size=(num_obs, 1), chunks=(1000, 1))
5555
5656
dtrain = xgb.dask.DaskDMatrix(client, X, y)
57+
# or
58+
# dtrain = xgb.dask.DaskQuantileDMatrix(client, X, y)
59+
# `DaskQuantileDMatrix` is available for the `hist` and `gpu_hist` tree methods.
5760
5861
output = xgb.dask.train(
5962
client,

doc/tutorials/external_memory.rst

Lines changed: 70 additions & 11 deletions
@@ -22,6 +22,15 @@ GPU-based training algorithm. We will introduce them in the following sections.
2222

2323
The feature is still experimental as of 2.0. The performance is not well optimized.
2424

25+
The external memory support has gone through multiple iterations and is still under heavy
26+
development. Like the :py:class:`~xgboost.QuantileDMatrix` with
27+
:py:class:`~xgboost.DataIter`, XGBoost loads data batch-by-batch using a custom iterator
28+
supplied by the user. However, unlike the :py:class:`~xgboost.QuantileDMatrix`, external
29+
memory will not concatenate the batches unless GPU is used (it uses a hybrid approach,
30+
more details follow). Instead, it will cache all batches on the external memory and fetch
31+
them on-demand. Go to the end of the document to see a comparison between
32+
`QuantileDMatrix` and external memory.
33+
2534
*************
2635
Data Iterator
2736
*************
@@ -113,10 +122,11 @@ External memory is supported by GPU algorithms (i.e. when ``tree_method`` is set
113122
``gpu_hist``). However, the algorithm used for GPU is different from the one used for
114123
CPU. When training on a CPU, the tree method iterates through all batches from external
115124
memory for each step of the tree construction algorithm. On the other hand, the GPU
116-
algorithm concatenates all batches into one and stores it in GPU memory. To reduce overall
117-
memory usage, users can utilize subsampling. The good news is that the GPU hist tree
118-
method supports gradient-based sampling, enabling users to set a low sampling rate without
119-
compromising accuracy.
125+
algorithm uses a hybrid approach. It iterates through the data at the beginning of
126+
each iteration and concatenates all batches into one in GPU memory. To reduce overall
127+
memory usage, users can utilize subsampling. The GPU hist tree method supports
128+
`gradient-based sampling`, enabling users to set a low sampling rate without compromising
129+
accuracy.
120130

121131
.. code-block:: python
122132
@@ -134,6 +144,8 @@ see `this paper <https://arxiv.org/abs/2005.09148>`_.
134144
When the GPU runs out of memory during iteration on external memory, the user might
135145
receive a segfault instead of an OOM exception.
136146

147+
.. _ext_remarks:
148+
137149
*******
138150
Remarks
139151
*******
@@ -142,17 +154,64 @@ When using external memory with XGBoost, data is divided into smaller chunks so
142154
a fraction of it needs to be stored in memory at any given time. It's important to note
143155
that this method only applies to the predictor data (``X``), while other data, like labels
144156
and internal runtime structures are concatenated. This means that memory reduction is most
145-
effective when dealing with wide datasets where ``X`` is larger compared to other data
146-
like ``y``, while it has little impact on slim datasets.
157+
effective when dealing with wide datasets where ``X`` is significantly larger in size
158+
compared to other data like ``y``, while it has little impact on slim datasets.
159+
160+
As one might expect, fetching data on-demand puts significant pressure on the storage
161+
device. Today's computing devices can process far more data than a storage device can read in a
162+
single unit of time. The gap is orders of magnitude. A GPU is capable of processing
163+
hundreds of gigabytes of floating-point data in a split second. On the other hand, a
164+
four-lane NVMe storage connected to a PCIe-4 slot usually has about 6GB/s of data transfer
165+
rate. As a result, the training is likely to be severely bounded by your storage
166+
device. Before adopting the external memory solution, some back-of-the-envelope calculations
167+
might help you see whether it's viable. For instance, suppose your NVMe drive can transfer 4GB
168+
(a fairly practical number) of data per second and you have 100GB of data in compressed
169+
XGBoost cache (which corresponds to a dense float32 numpy array with the size of 200GB,
170+
give or take). A tree with depth 8 needs at least 16 iterations through the data when the
171+
parameter is right. You need about 14 minutes to train a single tree without accounting
172+
for other overheads and assuming the computation overlaps with the IO. If your dataset
173+
happens to have TB-level size, then you might need thousands of trees to get a generalized
174+
model. These calculations can help you estimate the expected training time.
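The estimate above can be reproduced with a few lines of arithmetic. This is a rough sketch under a simplifying assumption that the text does not spell out: every pass re-reads the whole cache from disk at a constant rate and nothing is served from the OS page cache. The function name and its parameters are illustrative and are not part of XGBoost.

```python
def io_seconds_per_tree(cache_gb: float, read_gb_per_s: float, passes: int) -> float:
    """Lower bound on IO time for one tree if every pass re-reads the full cache."""
    return cache_gb / read_gb_per_s * passes


# Figures quoted in the text: 4GB/s drive, at least 16 passes for a depth-8 tree.
compressed = io_seconds_per_tree(cache_gb=100.0, read_gb_per_s=4.0, passes=16)
# Same drive, but reading the 200GB dense float32 equivalent instead.
uncompressed = io_seconds_per_tree(cache_gb=200.0, read_gb_per_s=4.0, passes=16)

print(f"100GB cache: {compressed:.0f} s ({compressed / 60:.1f} min) per tree")
print(f"200GB data:  {uncompressed:.0f} s ({uncompressed / 60:.1f} min) per tree")
```

Under this model the 100GB cache costs 400 seconds per tree and the 200GB footprint costs 800 seconds, which is closer to the 14-minute figure quoted above. Either number is a lower bound, since it ignores every overhead other than sequential reads.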
175+
176+
However, sometimes we can ameliorate this limitation. One should also consider that the OS
177+
(mostly talking about the Linux kernel) can usually cache the data on host memory. It only
178+
evicts pages when new data comes in and there's no room left. In practice, at least some
179+
portion of the data can persist on the host memory throughout the entire training
180+
session. We are aware of this cache when optimizing the external memory fetcher. The
181+
compressed cache is usually smaller than the raw input data, especially when the input is
182+
dense without any missing value. If the host memory can fit a significant portion of this
183+
compressed cache, then the performance should be decent after initialization. Our
184+
development so far focuses on two fronts of optimization for external memory:
185+
186+
- Avoid iterating through the data whenever appropriate.
187+
- If the OS can cache the data, the performance should be close to in-core training.
147188

148189
Starting with XGBoost 2.0, the implementation of external memory uses ``mmap``. It is not
149-
yet tested against system errors like disconnected network devices (`SIGBUS`). Also, it's
150-
worth noting that most tests have been conducted on Linux distributions.
190+
tested against system errors like disconnected network devices (`SIGBUS`). In the face of
191+
a bus error, you will see a hard crash and need to clean up the cache files. If the
192+
training session might take a long time and you are using solutions like NVMe-oF, we
193+
recommend checkpointing your model periodically. Also, it's worth noting that most tests
194+
have been conducted on Linux distributions.
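To illustrate what fetching through ``mmap`` looks like, the sketch below maps a file and decodes only the requested slice, leaving it to the OS to decide which pages stay in host memory. It uses only the Python standard library. The file layout (a flat array of little-endian float32 values) and the helper names are assumptions made for illustration and do not reflect XGBoost's actual cache format.

```python
import mmap
import os
import struct
import tempfile

ITEM_SIZE = struct.calcsize("<f")  # 4 bytes per float32


def write_cache(path, values):
    """Write the values as a flat little-endian float32 array."""
    with open(path, "wb") as fo:
        fo.write(struct.pack(f"<{len(values)}f", *values))


def read_batch(path, begin, end):
    """Map the file read-only and decode only elements [begin, end)."""
    with open(path, "rb") as fi:
        with mmap.mmap(fi.fileno(), 0, access=mmap.ACCESS_READ) as mapped:
            # Slicing touches only the pages that back this byte range.
            raw = mapped[begin * ITEM_SIZE:end * ITEM_SIZE]
    return list(struct.unpack(f"<{end - begin}f", raw))


with tempfile.TemporaryDirectory() as tmp:
    cache_path = os.path.join(tmp, "cache.bin")
    write_cache(cache_path, [float(i) for i in range(1024)])
    batch = read_batch(cache_path, 256, 260)

print(batch)
```

If the file behind the mapping disappears while it is in use, as with a disconnected network device, the access fails at the OS level, which is why the text recommends periodic checkpointing.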
151195

152-
Another important point to keep in mind is that creating the initial cache for XGBoost may
153-
take some time. The interface to external memory is through custom iterators, which may or
154-
may not be thread-safe. Therefore, initialization is performed sequentially.
155196

197+
Another important point to keep in mind is that creating the initial cache for XGBoost may
198+
take some time. The interface to external memory is through custom iterators, which we
199+
cannot assume to be thread-safe. Therefore, initialization is performed sequentially. Using
200+
the `xgboost.config_context` with `verbosity=2` can give you some information on what
201+
XGBoost is doing during the wait if you don't mind the extra output.
202+
203+
*******************************
204+
Compared to the QuantileDMatrix
205+
*******************************
206+
207+
Passing an iterator to the :py:class:`~xgboost.QuantileDMatrix` enables direct
208+
construction of `QuantileDMatrix` with data chunks. On the other hand, if it's passed to
209+
:py:class:`~xgboost.DMatrix`, it instead enables the external memory feature. The
210+
:py:class:`~xgboost.QuantileDMatrix` concatenates the data in memory after compression and
211+
doesn't fetch data during training. On the other hand, the external memory `DMatrix`
212+
fetches data batches from external memory on-demand. Use the `QuantileDMatrix` (with
213+
iterator if necessary) when you can fit most of your data in memory. The training would be
214+
an order of magnitude faster than using external memory.
156215

157216
****************
158217
Text File Inputs

doc/tutorials/index.rst

Lines changed: 9 additions & 9 deletions
@@ -11,22 +11,22 @@ See `Awesome XGBoost <https://github.com/dmlc/xgboost/tree/master/demo>`_ for mo
1111

1212
model
1313
saving_model
14+
learning_to_rank
15+
dart
16+
monotonic
17+
feature_interaction_constraint
18+
aft_survival_analysis
19+
categorical
20+
multioutput
21+
rf
1422
kubernetes
1523
Distributed XGBoost with XGBoost4J-Spark <https://xgboost.readthedocs.io/en/latest/jvm/xgboost4j_spark_tutorial.html>
1624
Distributed XGBoost with XGBoost4J-Spark-GPU <https://xgboost.readthedocs.io/en/latest/jvm/xgboost4j_spark_gpu_tutorial.html>
1725
dask
1826
spark_estimator
1927
ray
20-
dart
21-
monotonic
22-
rf
23-
feature_interaction_constraint
24-
learning_to_rank
25-
aft_survival_analysis
28+
external_memory
2629
c_api_tutorial
2730
input_format
2831
param_tuning
29-
external_memory
3032
custom_metric_obj
31-
categorical
32-
multioutput

doc/tutorials/param_tuning.rst

Lines changed: 43 additions & 0 deletions
@@ -58,3 +58,46 @@ This can affect the training of XGBoost model, and there are two ways to improve
5858

5959
- In such a case, you cannot re-balance the dataset
6060
- Set parameter ``max_delta_step`` to a finite number (say 1) to help convergence
61+
62+
63+
*********************
64+
Reducing Memory Usage
65+
*********************
66+
67+
If you are using an HPO library like :py:class:`sklearn.model_selection.GridSearchCV`,
68+
please control the number of threads it can use. It's best to let XGBoost run in
69+
parallel instead of asking `GridSearchCV` to run multiple experiments at the same
70+
time. For instance, creating a fold of data for cross validation can consume a significant
71+
amount of memory:
72+
73+
.. code-block:: python
74+
75+
# This creates a copy of dataset. X and X_train are both in memory at the same time.
76+
77+
# This happens for every thread at the same time if you run `GridSearchCV` with
78+
# `n_jobs` larger than 1
79+
80+
X_train, X_test, y_train, y_test = train_test_split(X, y)
81+
82+
.. code-block:: python
83+
84+
df = pd.DataFrame()
85+
# This creates a new copy of the dataframe, even if you specify the inplace parameter
86+
new_df = df.drop(...)
87+
88+
.. code-block:: python
89+
90+
array = np.array(...)
91+
# This may or may not make a copy of the data, depending on the type of the data
92+
array.astype(np.float32)
93+
94+
.. code-block:: python
95+
96+
# np by default uses double, do you actually need it?
97+
array = np.array(...)
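The two NumPy points above can be checked directly. The following is a small sketch, assuming only that NumPy is installed; the variable names are illustrative. It shows when ``astype`` copies and how much memory the default ``float64`` takes compared with ``float32``.

```python
import numpy as np

array = np.arange(6, dtype=np.float64)

# Changing the dtype always allocates a new buffer.
as_f32 = array.astype(np.float32)

# Same dtype with the default copy=True still makes a copy.
copied = array.astype(np.float64)

# Same dtype with copy=False returns the input array itself.
same = array.astype(np.float64, copy=False)

# Python floats become float64 unless a dtype is given.
default = np.array([1.0, 2.0, 3.0])
compact = np.array([1.0, 2.0, 3.0], dtype=np.float32)

print(np.shares_memory(array, as_f32), np.shares_memory(array, copied), same is array)
print(default.dtype, default.nbytes, compact.dtype, compact.nbytes)
```

Passing ``copy=False`` does not guarantee that no copy is made. It only avoids the copy when the existing array already satisfies the requested dtype and layout.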
98+
99+
You can find some more specific memory reduction practices scattered through the documents.
100+
For instance: :doc:`/tutorials/dask`, :doc:`/gpu/index`,
101+
:doc:`/contrib/scaling`. However, before going into these, being conscious about making
102+
data copies is a good starting point. It usually consumes a lot more memory than people
103+
expect.

rabit/include/rabit/internal/io.h

Lines changed: 3 additions & 8 deletions
@@ -19,8 +19,7 @@
1919
#include "rabit/internal/utils.h"
2020
#include "rabit/serializable.h"
2121

22-
namespace rabit {
23-
namespace utils {
22+
namespace rabit::utils {
2423
/*! \brief re-use definition of dmlc::SeekStream */
2524
using SeekStream = dmlc::SeekStream;
2625
/**
@@ -31,9 +30,6 @@ struct MemoryFixSizeBuffer : public SeekStream {
3130
// similar to SEEK_END in libc
3231
static std::size_t constexpr kSeekEnd = std::numeric_limits<std::size_t>::max();
3332

34-
protected:
35-
MemoryFixSizeBuffer() = default;
36-
3733
public:
3834
/**
3935
* @brief Ctor
@@ -68,7 +64,7 @@ struct MemoryFixSizeBuffer : public SeekStream {
6864
* @brief Current position in the buffer (stream).
6965
*/
7066
std::size_t Tell() override { return curr_ptr_; }
71-
virtual bool AtEnd() const { return curr_ptr_ == buffer_size_; }
67+
[[nodiscard]] virtual bool AtEnd() const { return curr_ptr_ == buffer_size_; }
7268

7369
protected:
7470
/*! \brief in memory buffer */
@@ -119,6 +115,5 @@ struct MemoryBufferStream : public SeekStream {
119115
/*! \brief current pointer */
120116
size_t curr_ptr_;
121117
}; // class MemoryBufferStream
122-
} // namespace utils
123-
} // namespace rabit
118+
} // namespace rabit::utils
124119
#endif // RABIT_INTERNAL_IO_H_

src/common/column_matrix.cc

Lines changed: 68 additions & 11 deletions
@@ -1,16 +1,27 @@
1-
/*!
2-
* Copyright 2017-2022 by XGBoost Contributors
1+
/**
2+
* Copyright 2017-2023, XGBoost Contributors
33
* \brief Utility for fast column-wise access
44
*/
55
#include "column_matrix.h"
66

7-
namespace xgboost {
8-
namespace common {
7+
#include <algorithm> // for transform
8+
#include <cstddef> // for size_t
9+
#include <cstdint> // for uint64_t, uint8_t
10+
#include <limits> // for numeric_limits
11+
#include <type_traits> // for remove_reference_t
12+
#include <vector> // for vector
13+
14+
#include "../data/gradient_index.h" // for GHistIndexMatrix
15+
#include "io.h" // for AlignedResourceReadStream, AlignedFileWriteStream
16+
#include "xgboost/base.h" // for bst_feature_t
17+
#include "xgboost/span.h" // for Span
18+
19+
namespace xgboost::common {
920
void ColumnMatrix::InitStorage(GHistIndexMatrix const& gmat, double sparse_threshold) {
1021
auto const nfeature = gmat.Features();
1122
const size_t nrow = gmat.Size();
1223
// identify type of each column
13-
type_.resize(nfeature);
24+
type_ = common::MakeFixedVecWithMalloc(nfeature, ColumnType{});
1425

1526
uint32_t max_val = std::numeric_limits<uint32_t>::max();
1627
for (bst_feature_t fid = 0; fid < nfeature; ++fid) {
@@ -34,7 +45,7 @@ void ColumnMatrix::InitStorage(GHistIndexMatrix const& gmat, double sparse_thres
3445

3546
// want to compute storage boundary for each feature
3647
// using variants of prefix sum scan
37-
feature_offsets_.resize(nfeature + 1);
48+
feature_offsets_ = common::MakeFixedVecWithMalloc(nfeature + 1, std::size_t{0});
3849
size_t accum_index = 0;
3950
feature_offsets_[0] = accum_index;
4051
for (bst_feature_t fid = 1; fid < nfeature + 1; ++fid) {
@@ -49,17 +60,63 @@ void ColumnMatrix::InitStorage(GHistIndexMatrix const& gmat, double sparse_thres
4960
SetTypeSize(gmat.MaxNumBinPerFeat());
5061
auto storage_size =
5162
feature_offsets_.back() * static_cast<std::underlying_type_t<BinTypeSize>>(bins_type_size_);
52-
index_.resize(storage_size, 0);
63+
64+
index_ = common::MakeFixedVecWithMalloc(storage_size, std::uint8_t{0});
65+
5366
if (!all_dense_column) {
54-
row_ind_.resize(feature_offsets_[nfeature]);
67+
row_ind_ = common::MakeFixedVecWithMalloc(feature_offsets_[nfeature], std::size_t{0});
5568
}
5669

5770
// store least bin id for each feature
5871
index_base_ = const_cast<uint32_t*>(gmat.cut.Ptrs().data());
5972

6073
any_missing_ = !gmat.IsDense();
6174

62-
missing_flags_.clear();
75+
missing_ = MissingIndicator{0, false};
76+
}
77+
78+
// IO procedures for external memory.
79+
bool ColumnMatrix::Read(AlignedResourceReadStream* fi, uint32_t const* index_base) {
80+
if (!common::ReadVec(fi, &index_)) {
81+
return false;
82+
}
83+
if (!common::ReadVec(fi, &type_)) {
84+
return false;
85+
}
86+
if (!common::ReadVec(fi, &row_ind_)) {
87+
return false;
88+
}
89+
if (!common::ReadVec(fi, &feature_offsets_)) {
90+
return false;
91+
}
92+
93+
if (!common::ReadVec(fi, &missing_.storage)) {
94+
return false;
95+
}
96+
missing_.InitView();
97+
98+
index_base_ = index_base;
99+
if (!fi->Read(&bins_type_size_)) {
100+
return false;
101+
}
102+
if (!fi->Read(&any_missing_)) {
103+
return false;
104+
}
105+
return true;
106+
}
107+
108+
std::size_t ColumnMatrix::Write(AlignedFileWriteStream* fo) const {
109+
std::size_t bytes{0};
110+
111+
bytes += common::WriteVec(fo, index_);
112+
bytes += common::WriteVec(fo, type_);
113+
bytes += common::WriteVec(fo, row_ind_);
114+
bytes += common::WriteVec(fo, feature_offsets_);
115+
bytes += common::WriteVec(fo, missing_.storage);
116+
117+
bytes += fo->Write(bins_type_size_);
118+
bytes += fo->Write(any_missing_);
119+
120+
return bytes;
63121
}
64-
} // namespace common
65-
} // namespace xgboost
122+
} // namespace xgboost::common
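The `Read` and `Write` pair above relies on both sides visiting the fields in the same order, with the reader bailing out on the first short read and the writer returning the number of bytes written. A minimal Python sketch of that pattern follows. The length-prefixed layout and the helper names `write_vec` and `read_vec` are assumptions for illustration only; they do not reproduce the format or alignment used by `common::WriteVec` and `common::ReadVec`.

```python
import io
import struct


def write_vec(fo, values, fmt):
    """Write an element count followed by the elements; return bytes written."""
    header = struct.pack("<Q", len(values))
    payload = struct.pack(f"<{len(values)}{fmt}", *values)
    fo.write(header)
    fo.write(payload)
    return len(header) + len(payload)


def read_vec(fi, fmt):
    """Read a vector written by write_vec; return None on a short read."""
    header = fi.read(8)
    if len(header) != 8:
        return None
    (count,) = struct.unpack("<Q", header)
    size = count * struct.calcsize(f"<{fmt}")
    payload = fi.read(size)
    if len(payload) != size:
        return None
    return list(struct.unpack(f"<{count}{fmt}", payload))


stream = io.BytesIO()
written = write_vec(stream, [3, 1, 4], "B")   # uint8 values, like a bin index
written += write_vec(stream, [0, 2, 5], "Q")  # uint64 values, like feature offsets

# The reader must consume the fields in exactly the order they were written.
stream.seek(0)
index = read_vec(stream, "B")
offsets = read_vec(stream, "Q")

# A truncated stream is reported instead of being decoded partially.
truncated = read_vec(io.BytesIO(b"\x01"), "B")

print(index, offsets, written, truncated)
```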
