@@ -146,64 +146,40 @@ __global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix,
 }
 
 template <typename GradientSumT>
-HistogramLaunchConfig InitGradientHistogram(int device_idx, int n_bins) {
-  // opt into maximum shared memory for the kernel
-  int max_shared_memory = dh::MaxSharedMemoryOptin(device_idx);
+void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
+                            common::Span<GradientPair const> gpair,
+                            common::Span<const uint32_t> d_ridx,
+                            common::Span<GradientSumT> histogram,
+                            GradientSumT rounding) {
+  // decide whether to use shared memory
+  int device = 0;
+  dh::safe_cuda(cudaGetDevice(&device));
+  int max_shared_memory = dh::MaxSharedMemoryOptin(device);
+  size_t smem_size = sizeof(GradientSumT) * matrix.NumBins();
+  bool shared = smem_size <= max_shared_memory;
+  smem_size = shared ? smem_size : 0;
+
+  // opt into maximum shared memory for the kernel if necessary
   auto kernel = SharedMemHistKernel<GradientSumT>;
-  dh::safe_cuda(cudaFuncSetAttribute
-                (kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
-                 max_shared_memory));
-
-  // find the optimal configuration for the specified bin count
-  HistogramLaunchConfig config;
-  config.shared = n_bins * sizeof(GradientSumT) <= max_shared_memory;
-  config.block_threads = 256;
-  int smem_size = config.shared ? n_bins * sizeof(GradientSumT) : 0;
-
-  if (config.shared) {
-    // find the optimal number of threads
-    int max_threads_per_mp = 0;
-    dh::safe_cuda(cudaDeviceGetAttribute
-                  (&max_threads_per_mp,
-                   cudaDevAttrMaxThreadsPerMultiProcessor, device_idx));
-    int warp_size = 32;
-    int max_kernel_threads_per_mp = 0;
-    for (int block_threads = 128; block_threads <= max_threads_per_mp;
-         block_threads += warp_size) {
-      int n_kernel_blocks_per_mp = 0;
-      dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor
-                    (&n_kernel_blocks_per_mp, kernel, block_threads, smem_size));
-      if (n_kernel_blocks_per_mp * block_threads > max_kernel_threads_per_mp) {
-        config.block_threads = unsigned(block_threads);
-        max_kernel_threads_per_mp = n_kernel_blocks_per_mp * block_threads;
-      }
-    }
+  if (shared) {
+    dh::safe_cuda(cudaFuncSetAttribute
+                  (kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
+                   max_shared_memory));
   }
-
+
+  // determine the launch configuration
+  unsigned block_threads = shared ? 1024 : 256;
   int n_mps = 0;
-  dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device_idx));
+  dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device));
   int n_blocks_per_mp = 0;
   dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor
-                (&n_blocks_per_mp, kernel, config.block_threads, smem_size));
-  config.grid_size = n_blocks_per_mp * n_mps;
-
-  return config;
-}
-
-template <typename GradientSumT>
-void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
-                            common::Span<GradientPair const> gpair,
-                            common::Span<const uint32_t> d_ridx,
-                            common::Span<GradientSumT> histogram,
-                            GradientSumT rounding, const HistogramLaunchConfig& config) {
-  const size_t smem_size =
-      config.shared ? sizeof(GradientSumT) * matrix.NumBins() : 0;
+                (&n_blocks_per_mp, kernel, block_threads, smem_size));
+  unsigned grid_size = n_blocks_per_mp * n_mps;
+
   auto n_elements = d_ridx.size() * matrix.row_stride;
-
-  auto kernel = SharedMemHistKernel<GradientSumT>;
-  dh::LaunchKernel {config.grid_size, config.block_threads, smem_size} (
+  dh::LaunchKernel {grid_size, block_threads, smem_size} (
       kernel, matrix, d_ridx, histogram.data(), gpair.data(), n_elements,
-      rounding, config.shared);
+      rounding, shared);
   dh::safe_cuda(cudaGetLastError());
 }
 
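(Illustration, not part of the patch.) For callers, the effect of this hunk is that the separate configuration step disappears. A minimal before/after sketch, where matrix, gpair, ridx, histogram, and rounding stand for device data the updater already holds:

// Before: query a launch configuration once, then pass it on every call.
auto config = InitGradientHistogram<GradientPairPrecise>(device_idx, n_bins);
BuildGradientHistogram(matrix, gpair, ridx, histogram, rounding, config);

// After: the builder derives shared-memory use and launch shape internally.
BuildGradientHistogram(matrix, gpair, ridx, histogram, rounding);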
@@ -212,20 +188,14 @@ template void BuildGradientHistogram<GradientPair>(
     common::Span<GradientPair const> gpair,
     common::Span<const uint32_t> ridx,
     common::Span<GradientPair> histogram,
-    GradientPair rounding, const HistogramLaunchConfig& config);
+    GradientPair rounding);
 
 template void BuildGradientHistogram<GradientPairPrecise>(
     EllpackDeviceAccessor const& matrix,
     common::Span<GradientPair const> gpair,
     common::Span<const uint32_t> ridx,
     common::Span<GradientPairPrecise> histogram,
-    GradientPairPrecise rounding, const HistogramLaunchConfig& config);
-
-template HistogramLaunchConfig InitGradientHistogram<GradientPair>
-(int device_idx, int n_bins);
-
-template HistogramLaunchConfig InitGradientHistogram<GradientPairPrecise>
-(int device_idx, int n_bins);
+    GradientPairPrecise rounding);
 
 }  // namespace tree
 }  // namespace xgboost
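(Illustration, not part of the patch.) The launch-configuration pattern the new BuildGradientHistogram relies on (opt the kernel into large dynamic shared memory when the histogram fits, then size the grid from occupancy) can be tried standalone. Below is a minimal sketch under these assumptions: DummyHistKernel and the bin count are placeholders, dh::MaxSharedMemoryOptin is approximated by cudaDevAttrMaxSharedMemoryPerBlockOptin, and dh::LaunchKernel by a plain triple-chevron launch.

#include <cstdio>
#include <cuda_runtime.h>

__global__ void DummyHistKernel(float* bins, int n_bins) {
  extern __shared__ float smem[];  // placeholder; a real kernel accumulates here
  (void)smem;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n_bins;
       i += gridDim.x * blockDim.x) {
    bins[i] += 1.0f;  // stand-in for gradient accumulation
  }
}

int main() {
  int device = 0;
  cudaGetDevice(&device);

  // analogue of dh::MaxSharedMemoryOptin: opt-in shared memory limit per block
  int max_shared_memory = 0;
  cudaDeviceGetAttribute(&max_shared_memory,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

  // decide whether the whole histogram fits in shared memory
  int n_bins = 4096;  // placeholder bin count
  size_t smem_size = sizeof(float) * n_bins;
  bool shared = smem_size <= static_cast<size_t>(max_shared_memory);
  smem_size = shared ? smem_size : 0;

  // opt into maximum shared memory for the kernel if necessary
  auto kernel = DummyHistKernel;
  if (shared) {
    cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                         max_shared_memory);
  }

  // determine the launch configuration from occupancy
  unsigned block_threads = shared ? 1024 : 256;
  int n_mps = 0;
  cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device);
  int n_blocks_per_mp = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_blocks_per_mp, kernel,
                                                block_threads, smem_size);
  unsigned grid_size = n_blocks_per_mp * n_mps;

  float* bins = nullptr;
  cudaMalloc(&bins, sizeof(float) * n_bins);
  cudaMemset(bins, 0, sizeof(float) * n_bins);
  kernel<<<grid_size, block_threads, smem_size>>>(bins, n_bins);
  cudaDeviceSynchronize();
  printf("shared=%d grid=%u block=%u smem=%zu\n",
         static_cast<int>(shared), grid_size, block_threads, smem_size);
  cudaFree(bins);
  return 0;
}

With 4096 float bins (16 KB) the shared path is taken on any recent GPU; raising n_bins until smem_size exceeds the opt-in limit exercises the 256-thread global-memory fallback, mirroring the 1024-vs-256 choice in the hunk above.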