Skip to content

Commit 3974c3f

Browse files
orgoroxinntaoJCBrouwer
authored
Add torch to setup_requires & dynamic import to prevent import errors when installing via pip (#514)
* dynamic import of torch to prevent import error when installing * Update setup.py Co-authored-by: Xintao <[email protected]> Co-authored-by: Hans Brouwer <[email protected]>
1 parent b4f48db commit 3974c3f

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

setup.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
import os
66
import subprocess
77
import time
8-
import torch
9-
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
108

119
version_file = 'basicsr/version.py'
1210

@@ -117,6 +115,12 @@ def get_requirements(filename='requirements.txt'):
117115
if __name__ == '__main__':
118116
cuda_ext = os.getenv('BASICSR_EXT') # whether compile cuda ext
119117
if cuda_ext == 'True':
118+
try:
119+
import torch
120+
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
121+
except ImportError:
122+
raise ImportError('Unable to import torch - torch is needed to build cuda extensions')
123+
120124
ext_modules = [
121125
make_cuda_ext(
122126
name='deform_conv_ext',
@@ -134,8 +138,10 @@ def get_requirements(filename='requirements.txt'):
134138
sources=['src/upfirdn2d.cpp'],
135139
sources_cuda=['src/upfirdn2d_kernel.cu']),
136140
]
141+
setup_kwargs = dict(cmdclass={'build_ext': BuildExtension})
137142
else:
138143
ext_modules = []
144+
setup_kwargs = dict()
139145

140146
write_version_py()
141147
setup(
@@ -159,8 +165,8 @@ def get_requirements(filename='requirements.txt'):
159165
'Programming Language :: Python :: 3.8',
160166
],
161167
license='Apache License 2.0',
162-
setup_requires=['cython', 'numpy'],
168+
setup_requires=['cython', 'numpy', 'torch'],
163169
install_requires=get_requirements(),
164170
ext_modules=ext_modules,
165-
cmdclass={'build_ext': BuildExtension},
166-
zip_safe=False)
171+
zip_safe=False,
172+
**setup_kwargs)

0 commit comments

Comments
 (0)