RMM integration plugin #5873 (Merged)

Commits (48)
b7a322d  [CI] Add RMM as an optional dependency (hcho3)
e15845d  Replace caching allocator with pool allocator from RMM (hcho3)
812c209  Revert "Replace caching allocator with pool allocator from RMM" (hcho3)
a891112  Use rmm::mr::get_default_resource() (hcho3)
b5eb54d  Try setting default resource (doesn't work yet) (hcho3)
6abd4c0  Allocate pool_mr in the heap (hcho3)
2bdbc23  Prevent leaking pool_mr handle (hcho3)
c723632  Separate EXPECT_DEATH() into separate test suite suffixed DeathTest (hcho3)
78c2254  Turn off death tests for RMM (hcho3)
a520fa1  Address reviewer's feedback (hcho3)
a73391c  Prevent leaking of cuda_mr (hcho3)
309efc0  Merge remote-tracking branch 'origin/master' into add_rmm (hcho3)
fa4ec11  Fix Jenkinsfile syntax (hcho3)
871fc29  Remove unnecessary function in Jenkinsfile (hcho3)
48051df  [CI] Install NCCL into RMM container (hcho3)
c0a05ce  Run Python tests (hcho3)
c12e0a6  Try building with RMM, CUDA 10.0 (hcho3)
a3e0e2f  Do not use RMM for CUDA 10.0 target (hcho3)
3aeab69  Actually test for test_rmm flag (hcho3)
862d580  Fix TestPythonGPU (hcho3)
2a064bf  Use CNMeM allocator, since pool allocator doesn't yet support multiGPU (hcho3)
ab4e7b4  Merge branch 'master' into add_rmm (hcho3)
dd05d7b  Merge remote-tracking branch 'origin/master' into add_rmm (hcho3)
a4da8c5  Merge remote-tracking branch 'upstream/master' into add_rmm (hcho3)
789021f  Use 10.0 container to build RMM-enabled XGBoost (hcho3)
f27d836  Revert "Use 10.0 container to build RMM-enabled XGBoost" (hcho3)
a4b86a9  Fix Jenkinsfile (hcho3)
e5eb262  [CI] Assign larger /dev/shm to NCCL (hcho3)
4cf7f00  Use 10.2 artifact to run multi-GPU Python tests (hcho3)
d023a50  Add CUDA 10.0 -> 11.0 cross-version test; remove CUDA 10.0 target (hcho3)
abc64a3  Rename Conda env rmm_test -> gpu_test (hcho3)
1e7e42e  Use env var to opt into CNMeM pool for C++ tests (hcho3)
f1eeaff  Merge branch 'master' into add_rmm (hcho3)
1069ae0  Use identical CUDA version for RMM builds and tests (hcho3)
99a7520  Use Pytest fixtures to enable RMM pool in Python tests (hcho3)
ecc16ec  Merge remote-tracking branch 'upstream/master' into add_rmm (hcho3)
92d1481  Move RMM to plugin/CMakeLists.txt; use PLUGIN_RMM (hcho3)
e74fd0d  Use per-device MR; use command arg in gtest (hcho3)
2ee04b3  Set CMake prefix path to use Conda env (hcho3)
87422a2  Use 0.15 nightly version of RMM (hcho3)
9021a75  Remove unnecessary header (hcho3)
377580a  Fix a unit test when cudf is missing
2f3c532  Merge remote-tracking branch 'upstream/master' into add_rmm (hcho3)
3df7cc3  Add RMM demos (hcho3)
567fb33  Remove print() (hcho3)
1e63c46  Use HostDeviceVector in GPU predictor (hcho3)
ad216c5  Simplify pytest setup; use LocalCUDACluster fixture (hcho3)
b4195cd  Address reviewers' comments (hcho3)
Files changed

README.md (new file)
@@ -0,0 +1,31 @@
Using XGBoost with RAPIDS Memory Manager (RMM) plugin (EXPERIMENTAL)
====================================================================
The [RAPIDS Memory Manager (RMM)](https://github.com/rapidsai/rmm) library provides a collection of
efficient memory allocators for NVIDIA GPUs. You can use XGBoost with these allocators by enabling
the RMM integration plugin.

The demos in this directory highlight one RMM allocator in particular: **the pool sub-allocator**.
This allocator addresses the high cost of `cudaMalloc()` by allocating a large chunk of memory
upfront. Subsequent allocations draw from this pool of already-allocated memory, avoiding
the overhead of calling `cudaMalloc()` directly. See
[these GTC talk slides](https://on-demand.gputechconf.com/gtc/2015/presentation/S5530-Stephen-Jones.pdf)
for more details.
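
As an illustration (not part of the README itself), the pool sub-allocator can be enabled directly
through RMM's Python API; the 1 GiB `initial_pool_size` below is an arbitrary example value:
```
import rmm

# Reserve a 1 GiB pool up front (example size). Subsequent device
# allocations are served from this pool instead of hitting cudaMalloc().
rmm.reinitialize(pool_allocator=True, initial_pool_size=1024**3)
```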

Before running the demos, ensure that XGBoost is compiled with the RMM plugin enabled. To do this,
run CMake with the option `-DPLUGIN_RMM=ON` (`-DUSE_CUDA=ON` is also required):
```
cmake .. -DUSE_CUDA=ON -DUSE_NCCL=ON -DPLUGIN_RMM=ON
make -j4
```
CMake will attempt to locate the RMM library in your build environment. You may choose to build
RMM from source or install it with the Conda package manager. If CMake cannot find RMM, specify
its location with `CMAKE_PREFIX_PATH`:
```
# If using Conda:
cmake .. -DUSE_CUDA=ON -DUSE_NCCL=ON -DPLUGIN_RMM=ON -DCMAKE_PREFIX_PATH=$CONDA_PREFIX
# If RMM is installed in a custom location:
cmake .. -DUSE_CUDA=ON -DUSE_NCCL=ON -DPLUGIN_RMM=ON -DCMAKE_PREFIX_PATH=/path/to/rmm
```

* [Using RMM with a single GPU](./rmm_singlegpu.py)
* [Using RMM with a local Dask cluster consisting of multiple GPUs](./rmm_mgpu_with_dask.py)

rmm_mgpu_with_dask.py (new file)
@@ -0,0 +1,27 @@
import xgboost as xgb
from sklearn.datasets import make_classification
import dask
import dask.array  # ensure the dask.array submodule is importable
from dask.distributed import Client
from dask_cuda import LocalCUDACluster


def main(client):
    X, y = make_classification(n_samples=10000, n_informative=5, n_classes=3)
    X = dask.array.from_array(X)
    y = dask.array.from_array(y)
    dtrain = xgb.dask.DaskDMatrix(client, X, label=y)

    params = {'max_depth': 8, 'eta': 0.01, 'objective': 'multi:softprob', 'num_class': 3,
              'tree_method': 'gpu_hist'}
    output = xgb.dask.train(client, params, dtrain, num_boost_round=100,
                            evals=[(dtrain, 'train')])
    bst = output['booster']
    history = output['history']
    for i, e in enumerate(history['train']['merror']):
        print(f'[{i}] train-merror: {e}')


if __name__ == '__main__':
    # To use the RMM pool allocator with a GPU Dask cluster, pass the rmm_pool_size option
    # to the LocalCUDACluster constructor.
    with LocalCUDACluster(rmm_pool_size='2GB') as cluster:
        with Client(cluster) as client:
            main(client)
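
A hedged aside (not part of the demo file): to confirm that each Dask worker really has an RMM pool
active, you can inspect the current memory resource on every worker. This assumes a recent RMM
version exposing `rmm.mr.get_current_device_resource()` and a `client` in scope as above:
```
import rmm

def active_resource():
    # With rmm_pool_size set, this should report PoolMemoryResource.
    return type(rmm.mr.get_current_device_resource()).__name__

# Runs the check on all workers; returns a dict of worker address -> resource name.
print(client.run(active_resource))
```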

rmm_singlegpu.py (new file)
@@ -0,0 +1,14 @@
import xgboost as xgb
import rmm
from sklearn.datasets import make_classification

# Initialize RMM pool allocator
rmm.reinitialize(pool_allocator=True)

X, y = make_classification(n_samples=10000, n_informative=5, n_classes=3)
dtrain = xgb.DMatrix(X, label=y)

params = {'max_depth': 8, 'eta': 0.01, 'objective': 'multi:softprob', 'num_class': 3,
          'tree_method': 'gpu_hist'}
# XGBoost will automatically use the RMM pool allocator
bst = xgb.train(params, dtrain, num_boost_round=100, evals=[(dtrain, 'train')])
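
As a short follow-up (not part of the demo file): once trained, the booster can run predictions as
usual, and with the plugin enabled the GPU predictor's device allocations should also be served by
RMM. The `preds` variable below is illustrative, reusing `bst` and `dtrain` from the demo:
```
# Predict class probabilities; shape is (n_samples, num_class) for multi:softprob.
preds = bst.predict(dtrain)
print(preds.shape)  # e.g. (10000, 3)
```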