3
3
import os
4
4
import unittest
5
5
6
+ import numpy as np
6
7
from mock import Mock , patch
7
8
9
+ try :
10
+ import cupy as cp
11
+ except ImportError :
12
+ pass
13
+
8
14
from dasf .pipeline .executors import LocalExecutor
9
15
from dasf .pipeline .types import TaskExecutorType
10
16
from dasf .utils .funcs import is_gpu_supported
@@ -15,4 +21,48 @@ class TestLocalExecutor(unittest.TestCase):
15
21
def test_local_executor_no_gpu (self ):
16
22
local = LocalExecutor ()
17
23
18
- self .assertTrue (local .dtype , TaskExecutorType .single_cpu )
24
+ self .assertEqual (local .dtype , TaskExecutorType .single_cpu )
25
+ self .assertEqual (local .get_backend (), np )
26
+
27
+ @patch ('dasf.pipeline.executors.wrapper.is_gpu_supported' , Mock (return_value = False ))
28
+ def test_local_executor_no_gpu_but_use_gpu (self ):
29
+ local = LocalExecutor (use_gpu = True )
30
+
31
+ self .assertEqual (local .dtype , TaskExecutorType .single_cpu )
32
+ self .assertEqual (local .get_backend (), np )
33
+
34
+ @unittest .skipIf (not is_gpu_supported (),
35
+ "not supported CUDA in this platform" )
36
+ def test_local_executor_use_gpu (self ):
37
+ local = LocalExecutor (use_gpu = True )
38
+
39
+ self .assertEqual (local .dtype , TaskExecutorType .single_gpu )
40
+ self .assertEqual (local .get_backend (), cp )
41
+
42
+ @unittest .skipIf (not is_gpu_supported (),
43
+ "not supported CUDA in this platform" )
44
+ def test_local_executor_use_gpu_backend_cupy (self ):
45
+ local = LocalExecutor (use_gpu = True , backend = "cupy" )
46
+
47
+ self .assertEqual (local .dtype , TaskExecutorType .single_gpu )
48
+ self .assertEqual (local .get_backend (), cp )
49
+
50
+ @unittest .skipIf (not is_gpu_supported (),
51
+ "not supported CUDA in this platform" )
52
+ def test_local_executor_use_gpu_backend_cupy (self ):
53
+ local = LocalExecutor (backend = "cupy" )
54
+
55
+ self .assertEqual (local .dtype , TaskExecutorType .single_gpu )
56
+ self .assertEqual (local .backend , "cupy" )
57
+ self .assertEqual (local .get_backend (), cp )
58
+
59
+ @unittest .skipIf (not is_gpu_supported (),
60
+ "not supported CUDA in this platform" )
61
+ @patch ('dasf.pipeline.executors.wrapper.rmm.reinitialize' )
62
+ def test_local_executor_with_rmm (self , rmm ):
63
+ local = LocalExecutor (gpu_allocator = "rmm" )
64
+
65
+ self .assertEqual (local .dtype , TaskExecutorType .single_gpu )
66
+ self .assertEqual (local .get_backend (), cp )
67
+
68
+ rmm .assert_called_once_with (managed_memory = True )
0 commit comments