@@ -217,12 +217,96 @@ static int accelerator_cuda_check_vmm(CUdeviceptr dbuf, CUmemorytype *mem_type,
217
217
return 0 ;
218
218
}
219
219
220
+ static int accelerator_cuda_check_mpool (CUdeviceptr dbuf , CUmemorytype * mem_type ,
221
+ int * dev_id )
222
+ {
223
+ #if OPAL_CUDA_VMM_SUPPORT
224
+ static int device_count = -1 ;
225
+ static int mpool_supported = -1 ;
226
+ CUresult result ;
227
+ CUmemoryPool mpool ;
228
+ CUmemAccess_flags flags ;
229
+ CUmemLocation location ;
230
+
231
+ if (mpool_supported <= 0 ) {
232
+ if (mpool_supported == -1 ) {
233
+ if (device_count == -1 ) {
234
+ result = cuDeviceGetCount (& device_count );
235
+ if (result != CUDA_SUCCESS || (0 == device_count )) {
236
+ mpool_supported = 0 ; /* never check again */
237
+ device_count = 0 ;
238
+ return 0 ;
239
+ }
240
+ }
241
+
242
+ /* assume uniformity of devices */
243
+ result = cuDeviceGetAttribute (& mpool_supported ,
244
+ CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED , 0 );
245
+ if (result != CUDA_SUCCESS ) {
246
+ mpool_supported = 0 ;
247
+ }
248
+ }
249
+ if (0 == mpool_supported ) {
250
+ return 0 ;
251
+ }
252
+ }
253
+
254
+ result = cuPointerGetAttribute (& mpool , CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE ,
255
+ dbuf );
256
+ if (CUDA_SUCCESS != result ) {
257
+ return 0 ;
258
+ }
259
+
260
+ /* check if device has access */
261
+ for (int i = 0 ; i < device_count ; i ++ ) {
262
+ location .type = CU_MEM_LOCATION_TYPE_DEVICE ;
263
+ location .id = i ;
264
+ result = cuMemPoolGetAccess (& flags , mpool , & location );
265
+ if ((CUDA_SUCCESS == result ) &&
266
+ (CU_MEM_ACCESS_FLAGS_PROT_READWRITE == flags )) {
267
+ * mem_type = CU_MEMORYTYPE_DEVICE ;
268
+ * dev_id = i ;
269
+ return 1 ;
270
+ }
271
+ }
272
+
273
+ /* host must have access as device access possibility is exhausted */
274
+ * mem_type = CU_MEMORYTYPE_HOST ;
275
+ * dev_id = MCA_ACCELERATOR_NO_DEVICE_ID ;
276
+ return 0 ;
277
+ #endif
278
+
279
+ return 0 ;
280
+ }
281
+
282
+ static int accelerator_cuda_get_primary_context (CUdevice dev_id , CUcontext * pctx )
283
+ {
284
+ CUresult result ;
285
+ unsigned int flags ;
286
+ int active ;
287
+
288
+ result = cuDevicePrimaryCtxGetState (dev_id , & flags , & active );
289
+ if (CUDA_SUCCESS != result ) {
290
+ return OPAL_ERROR ;
291
+ }
292
+
293
+ if (active ) {
294
+ result = cuDevicePrimaryCtxRetain (pctx , dev_id );
295
+ return OPAL_SUCCESS ;
296
+ }
297
+
298
+ return OPAL_ERROR ;
299
+ }
300
+
220
301
static int accelerator_cuda_check_addr (const void * addr , int * dev_id , uint64_t * flags )
221
302
{
222
303
CUresult result ;
223
304
int is_vmm = 0 ;
305
+ int is_mpool_ptr = 0 ;
224
306
int vmm_dev_id = MCA_ACCELERATOR_NO_DEVICE_ID ;
307
+ int mpool_dev_id = MCA_ACCELERATOR_NO_DEVICE_ID ;
225
308
CUmemorytype vmm_mem_type = 0 ;
309
+ CUmemorytype mpool_mem_type = 0 ;
226
310
CUmemorytype mem_type = 0 ;
227
311
CUdeviceptr dbuf = (CUdeviceptr ) addr ;
228
312
CUcontext ctx = NULL , mem_ctx = NULL ;
@@ -235,6 +319,7 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
235
319
* flags = 0 ;
236
320
237
321
is_vmm = accelerator_cuda_check_vmm (dbuf , & vmm_mem_type , & vmm_dev_id );
322
+ is_mpool_ptr = accelerator_cuda_check_mpool (dbuf , & mpool_mem_type , & mpool_dev_id );
238
323
239
324
#if OPAL_CUDA_GET_ATTRIBUTES
240
325
uint32_t is_managed = 0 ;
@@ -268,6 +353,9 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
268
353
if (is_vmm && (vmm_mem_type == CU_MEMORYTYPE_DEVICE )) {
269
354
mem_type = CU_MEMORYTYPE_DEVICE ;
270
355
* dev_id = vmm_dev_id ;
356
+ } else if (is_mpool_ptr && (mpool_mem_type == CU_MEMORYTYPE_DEVICE )) {
357
+ mem_type = CU_MEMORYTYPE_DEVICE ;
358
+ * dev_id = mpool_dev_id ;
271
359
} else {
272
360
/* Host memory, nothing to do here */
273
361
return 0 ;
@@ -278,6 +366,8 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
278
366
} else {
279
367
if (is_vmm ) {
280
368
* dev_id = vmm_dev_id ;
369
+ } else if (is_mpool_ptr ) {
370
+ * dev_id = mpool_dev_id ;
281
371
} else {
282
372
/* query the device from the context */
283
373
* dev_id = accelerator_cuda_get_device_id (mem_ctx );
@@ -296,13 +386,18 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
296
386
if (is_vmm && (vmm_mem_type == CU_MEMORYTYPE_DEVICE )) {
297
387
mem_type = CU_MEMORYTYPE_DEVICE ;
298
388
* dev_id = vmm_dev_id ;
389
+ } else if (is_mpool_ptr && (mpool_mem_type == CU_MEMORYTYPE_DEVICE )) {
390
+ mem_type = CU_MEMORYTYPE_DEVICE ;
391
+ * dev_id = mpool_dev_id ;
299
392
} else {
300
393
/* Host memory, nothing to do here */
301
394
return 0 ;
302
395
}
303
396
} else {
304
397
if (is_vmm ) {
305
398
* dev_id = vmm_dev_id ;
399
+ } else if (is_mpool_ptr ) {
400
+ * dev_id = mpool_dev_id ;
306
401
} else {
307
402
result = cuPointerGetAttribute (& mem_ctx ,
308
403
CU_POINTER_ATTRIBUTE_CONTEXT , dbuf );
@@ -336,14 +431,18 @@ static int accelerator_cuda_check_addr(const void *addr, int *dev_id, uint64_t *
336
431
return OPAL_ERROR ;
337
432
}
338
433
#endif /* OPAL_CUDA_GET_ATTRIBUTES */
339
- if (is_vmm ) {
340
- /* This function is expected to set context if pointer is device
341
- * accessible but VMM allocations have NULL context associated
342
- * which cannot be set against the calling thread */
343
- opal_output (0 ,
344
- "CUDA: unable to set context with the given pointer"
345
- "ptr=%p aborting..." , addr );
346
- return OPAL_ERROR ;
434
+ if (is_vmm || is_mpool_ptr ) {
435
+ if (OPAL_SUCCESS ==
436
+ accelerator_cuda_get_primary_context (
437
+ is_vmm ? vmm_dev_id : mpool_dev_id , & mem_ctx )) {
438
+ /* As VMM/mempool allocations have no context associated
439
+ * with them, check if device primary context can be set */
440
+ } else {
441
+ opal_output (0 ,
442
+ "CUDA: unable to set ctx with the given pointer"
443
+ "ptr=%p aborting..." , addr );
444
+ return OPAL_ERROR ;
445
+ }
347
446
}
348
447
349
448
result = cuCtxSetCurrent (mem_ctx );
0 commit comments