Skip to content

Commit d91879e

Browse files
committed
Support online build on Windows
1 parent a526307 commit d91879e

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

python/setup.py

+20-13
Original file line numberDiff line numberDiff line change
@@ -181,11 +181,13 @@ def get_json_package_info():
181181
def get_llvm_package_info():
182182
system = platform.system()
183183
try:
184-
arch = {"x86_64": "x64", "arm64": "arm64", "aarch64": "arm64"}[platform.machine()]
184+
arch = {"x86_64": "x64", "AMD64": "x64", "arm64": "arm64", "aarch64": "arm64"}[platform.machine()]
185185
except KeyError:
186186
arch = platform.machine()
187187
if system == "Darwin":
188188
system_suffix = f"macos-{arch}"
189+
elif system == "Windows":
190+
system_suffix = f"windows-{arch}"
189191
elif system == "Linux":
190192
if arch == 'arm64':
191193
system_suffix = 'ubuntu-arm64'
@@ -263,6 +265,17 @@ def update_symlink(link_path, source_path):
263265
link_path.symlink_to(source_path, target_is_directory=True)
264266

265267

268+
def download_and_extract_archive(url, extract_path):
269+
with open_url(url) as response:
270+
if url.endswith(".zip"):
271+
file_bytes = BytesIO(response.read())
272+
with zipfile.ZipFile(file_bytes, "r") as file:
273+
file.extractall(path=extract_path)
274+
else:
275+
with tarfile.open(fileobj=response, mode="r|*") as file:
276+
file.extractall(path=extract_path)
277+
278+
266279
def get_thirdparty_packages(packages: list):
267280
triton_cache_path = get_triton_cache_path()
268281
thirdparty_cmake_args = []
@@ -284,14 +297,7 @@ def get_thirdparty_packages(packages: list):
284297
shutil.rmtree(package_root_dir)
285298
os.makedirs(package_root_dir, exist_ok=True)
286299
print(f'downloading and extracting {p.url} ...')
287-
with open_url(p.url) as response:
288-
if p.url.endswith(".zip"):
289-
file_bytes = BytesIO(response.read())
290-
with zipfile.ZipFile(file_bytes, "r") as file:
291-
file.extractall(path=package_root_dir)
292-
else:
293-
with tarfile.open(fileobj=response, mode="r|*") as file:
294-
file.extractall(path=package_root_dir)
300+
download_and_extract_archive(p.url, package_root_dir)
295301
# write version url to package_dir
296302
with open(os.path.join(package_dir, "version.txt"), "w") as f:
297303
f.write(p.url)
@@ -316,9 +322,11 @@ def download_and_copy(name, src_func, dst_path, variable, version, url_func):
316322
system = platform.system()
317323
arch = platform.machine()
318324
# NOTE: This might be wrong for jetson if both grace chips and jetson chips return aarch64
319-
arch = {"arm64": "sbsa", "aarch64": "sbsa"}.get(arch, arch)
320-
supported = {"Linux": "linux", "Darwin": "linux"}
325+
arch = {"AMD64": "x86_64", "arm64": "sbsa", "aarch64": "sbsa"}.get(arch, arch)
326+
supported = {"Linux": "linux", "Darwin": "linux", "Windows": "windows"}
321327
url = url_func(supported[system], arch, version)
328+
if system == "Windows":
329+
url = url.replace(".tar.xz", ".zip")
322330
src_path = src_func(supported[system], arch, version)
323331
tmp_path = os.path.join(triton_cache_path, "nvidia", name) # path to cache the download
324332
dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", dst_path) # final binary path
@@ -331,8 +339,7 @@ def download_and_copy(name, src_func, dst_path, variable, version, url_func):
331339
download = download or curr_version.group(1) != version
332340
if download:
333341
print(f'downloading and extracting {url} ...')
334-
file = tarfile.open(fileobj=open_url(url), mode="r|*")
335-
file.extractall(path=tmp_path)
342+
download_and_extract_archive(url, tmp_path)
336343
os.makedirs(os.path.split(dst_path)[0], exist_ok=True)
337344
print(f'copy {src_path} to {dst_path} ...')
338345
if os.path.isdir(src_path):

0 commit comments

Comments
 (0)