5
5
import os
6
6
import subprocess
7
7
import time
8
- import torch
9
- from torch .utils .cpp_extension import BuildExtension , CppExtension , CUDAExtension
10
8
11
9
version_file = 'basicsr/version.py'
12
10
@@ -117,6 +115,12 @@ def get_requirements(filename='requirements.txt'):
117
115
if __name__ == '__main__' :
118
116
cuda_ext = os .getenv ('BASICSR_EXT' ) # whether compile cuda ext
119
117
if cuda_ext == 'True' :
118
+ try :
119
+ import torch
120
+ from torch .utils .cpp_extension import BuildExtension , CppExtension , CUDAExtension
121
+ except ImportError :
122
+ raise ImportError ('Unable to import torch - torch is needed to build cuda extensions' )
123
+
120
124
ext_modules = [
121
125
make_cuda_ext (
122
126
name = 'deform_conv_ext' ,
@@ -134,8 +138,10 @@ def get_requirements(filename='requirements.txt'):
134
138
sources = ['src/upfirdn2d.cpp' ],
135
139
sources_cuda = ['src/upfirdn2d_kernel.cu' ]),
136
140
]
141
+ setup_kwargs = dict (cmdclass = {'build_ext' : BuildExtension })
137
142
else :
138
143
ext_modules = []
144
+ setup_kwargs = dict ()
139
145
140
146
write_version_py ()
141
147
setup (
@@ -159,8 +165,8 @@ def get_requirements(filename='requirements.txt'):
159
165
'Programming Language :: Python :: 3.8' ,
160
166
],
161
167
license = 'Apache License 2.0' ,
162
- setup_requires = ['cython' , 'numpy' ],
168
+ setup_requires = ['cython' , 'numpy' , 'torch' ],
163
169
install_requires = get_requirements (),
164
170
ext_modules = ext_modules ,
165
- cmdclass = { 'build_ext' : BuildExtension } ,
166
- zip_safe = False )
171
+ zip_safe = False ,
172
+ ** setup_kwargs )
0 commit comments