Skip to content

Commit e15845d

Browse files
committed
Replace caching allocator with pool allocator from RMM
1 parent b7a322d commit e15845d

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

src/common/device_helpers.cuh

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ struct XGBDefaultDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
420420
* allocations if verbose. Does not initialise memory on construction.
421421
*/
422422
template <class T>
423-
struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
423+
struct XGBCachingDeviceAllocatorImpl : thrust::device_malloc_allocator<T> {
424424
using pointer = thrust::device_ptr<T>; // NOLINT
425425
template<typename U>
426426
struct rebind // NOLINT
@@ -462,8 +462,13 @@ using XGBDeviceAllocator = detail::XGBDefaultDeviceAllocatorImpl<T>;
462462
/*! Be careful that the initialization constructor is a no-op, which means calling
463463
* `vec.resize(n)` won't initialize the memory region to 0. Instead use
464464
* `vec.resize(n, 0)`*/
465+
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
466+
template <typename T>
467+
using XGBCachingDeviceAllocator = detail::XGBDefaultDeviceAllocatorImpl<T>;
468+
#else // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
465469
template <typename T>
466470
using XGBCachingDeviceAllocator = detail::XGBCachingDeviceAllocatorImpl<T>;
471+
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
467472
/** \brief Specialisation of thrust device vector using custom allocator. */
468473
template <typename T>
469474
using device_vector = thrust::device_vector<T, XGBDeviceAllocator<T>>; // NOLINT

0 commit comments

Comments
 (0)