
Commit 6dc3066

Adding more tests to wrapper executor
Signed-off-by: Julio Faracco <[email protected]>
1 parent 82e12a4 commit 6dc3066

2 files changed (+58, −6 lines changed)

dasf/pipeline/executors/wrapper.py (+7 −5)

@@ -1,8 +1,11 @@
 #!/usr/bin/env python3
 
+import numpy as np
+
 try:
     import cupy as cp
     import rmm
+    from rmm.allocators.cupy import rmm_cupy_allocator
 except ImportError:  # pragma: no cover
     pass
 
@@ -30,7 +33,7 @@ def __init__(self,
 
         if gpu_allocator == "rmm" and self.dtype == TaskExecutorType.single_gpu:
             rmm.reinitialize(managed_memory=True)
-            cp.cuda.set_allocator(rmm.rmm_cupy_allocator)
+            cp.cuda.set_allocator(rmm_cupy_allocator)
 
     @property
     def ngpus(self) -> int:
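For context: newer RMM releases moved the CuPy hook out of the top-level namespace, so rmm.rmm_cupy_allocator became rmm.allocators.cupy.rmm_cupy_allocator, which is what the hunk above adapts to. A minimal standalone sketch of the same pattern (assumes a CUDA machine with cupy and rmm installed; not part of the commit):

    import cupy as cp
    import rmm
    from rmm.allocators.cupy import rmm_cupy_allocator

    # Back all CuPy allocations with RMM's managed (unified) memory pool.
    rmm.reinitialize(managed_memory=True)
    cp.cuda.set_allocator(rmm_cupy_allocator)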
@@ -47,11 +50,10 @@ def post_run(self, pipeline):
         pass
 
     def get_backend(self):
-        if self.backend == "numpy" and \
-                self.dtype == TaskExecutorType.single_gpu:
-            return eval("cupy")
+        if self.dtype == TaskExecutorType.single_gpu:
+            return eval("cp")
 
-        return eval(self.backend)
+        return eval("np")
 
     def execute(self, fn, *args, **kwargs):
         if get_backend_supported(fn):
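With this change, get_backend() keys solely on the executor's dtype: it returns the cupy module on single_gpu and the numpy module otherwise, rather than consulting the backend string. A usage sketch (the array work below is illustrative, not from the commit):

    from dasf.pipeline.executors import LocalExecutor

    local = LocalExecutor()      # resolves to single_cpu on a machine without a GPU
    xp = local.get_backend()     # numpy here; cupy on a single_gpu executor

    # numpy and cupy share an array API, so backend-agnostic code
    # can use the returned module directly.
    data = xp.arange(10, dtype=xp.float32)
    print(xp.sum(data))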

tests/pipeline/executors/test_wrapper.py (+51 −1)

@@ -3,8 +3,14 @@
 import os
 import unittest
 
+import numpy as np
 from mock import Mock, patch
 
+try:
+    import cupy as cp
+except ImportError:
+    pass
+
 from dasf.pipeline.executors import LocalExecutor
 from dasf.pipeline.types import TaskExecutorType
 from dasf.utils.funcs import is_gpu_supported
@@ -15,4 +21,48 @@ class TestLocalExecutor(unittest.TestCase):
     def test_local_executor_no_gpu(self):
         local = LocalExecutor()
 
-        self.assertTrue(local.dtype, TaskExecutorType.single_cpu)
+        self.assertEqual(local.dtype, TaskExecutorType.single_cpu)
+        self.assertEqual(local.get_backend(), np)
+
+    @patch('dasf.pipeline.executors.wrapper.is_gpu_supported', Mock(return_value=False))
+    def test_local_executor_no_gpu_but_use_gpu(self):
+        local = LocalExecutor(use_gpu=True)
+
+        self.assertEqual(local.dtype, TaskExecutorType.single_cpu)
+        self.assertEqual(local.get_backend(), np)
+
+    @unittest.skipIf(not is_gpu_supported(),
+                     "not supported CUDA in this platform")
+    def test_local_executor_use_gpu(self):
+        local = LocalExecutor(use_gpu=True)
+
+        self.assertEqual(local.dtype, TaskExecutorType.single_gpu)
+        self.assertEqual(local.get_backend(), cp)
+
+    @unittest.skipIf(not is_gpu_supported(),
+                     "not supported CUDA in this platform")
+    def test_local_executor_use_gpu_backend_cupy(self):
+        local = LocalExecutor(use_gpu=True, backend="cupy")
+
+        self.assertEqual(local.dtype, TaskExecutorType.single_gpu)
+        self.assertEqual(local.get_backend(), cp)
+
+    @unittest.skipIf(not is_gpu_supported(),
+                     "not supported CUDA in this platform")
+    def test_local_executor_backend_cupy(self):
+        local = LocalExecutor(backend="cupy")
+
+        self.assertEqual(local.dtype, TaskExecutorType.single_gpu)
+        self.assertEqual(local.backend, "cupy")
+        self.assertEqual(local.get_backend(), cp)
+
+    @unittest.skipIf(not is_gpu_supported(),
+                     "not supported CUDA in this platform")
+    @patch('dasf.pipeline.executors.wrapper.rmm.reinitialize')
+    def test_local_executor_with_rmm(self, rmm):
+        local = LocalExecutor(gpu_allocator="rmm")
+
+        self.assertEqual(local.dtype, TaskExecutorType.single_gpu)
+        self.assertEqual(local.get_backend(), cp)
+
+        rmm.assert_called_once_with(managed_memory=True)
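The new tests cover both sides of the backend switch: the CPU cases assert get_backend() is np, the GPU cases (guarded by unittest.skipIf when is_gpu_supported() is false) assert it is cp, and the RMM case patches rmm.reinitialize with mock so no real allocator is touched. To run just this file locally (command is an assumption about the project's test setup):

    python -m pytest tests/pipeline/executors/test_wrapper.py -v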
