Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 58fe42f

Browse files
committed
Enable the easy download of the deployment.tar.gz (#379)
1 parent f269c02 commit 58fe42f

File tree

3 files changed

+97
-8
lines changed

3 files changed

+97
-8
lines changed

src/sparsezoo/model/model.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
save_outputs_to_tar,
3232
)
3333
from sparsezoo.objects import (
34-
AliasedSelectDirectory,
3534
Directory,
3635
File,
3736
NumpyDirectory,
@@ -138,12 +137,26 @@ def __init__(self, source: str, download_path: Optional[str] = None):
138137
files, directory_class=Directory, display_name="sample-labels"
139138
)
140139

141-
self.deployment: AliasedSelectDirectory = self._directory_from_files(
140+
self.deployment: SelectDirectory = self._directory_from_files(
142141
files,
143-
directory_class=AliasedSelectDirectory,
142+
directory_class=SelectDirectory,
144143
display_name="deployment",
145-
download_alias="deployment.tar.gz",
146144
stub_params=self.stub_params,
145+
allow_multiple_outputs=True,
146+
)
147+
148+
if isinstance(self.deployment, list):
149+
# if there are multiple deployment directories
150+
# (this may happen due to the presence of both
151+
# - deployment directory
152+
# - deployment.tar.gz file)
153+
# we need to choose one (they are identical)
154+
self.deployment = self.deployment[0]
155+
156+
self.deployment_tar: SelectDirectory = self._directory_from_files(
157+
files,
158+
directory_class=SelectDirectory,
159+
display_name="deployment.tar.gz",
147160
)
148161

149162
self.onnx_folder: Directory = self._directory_from_files(
@@ -196,6 +209,7 @@ def __init__(self, source: str, download_path: Optional[str] = None):
196209
self._files_dictionary = {
197210
"training": self.training,
198211
"deployment": self.deployment,
212+
"deployment.tar.gz": self.deployment_tar,
199213
"onnx_folder": self.onnx_folder,
200214
"logs": self.logs,
201215
"sample_originals": self.sample_originals,
@@ -233,9 +247,9 @@ def deployment_directory_path(self) -> str:
233247
deployment directory if compressed
234248
"""
235249
# trigger initial download if not downloaded
236-
self.deployment.path
237-
if self.deployment.is_archive:
238-
self.deployment.unzip()
250+
self.deployment_tar.path
251+
if self.deployment_tar.is_archive:
252+
self.deployment_tar.unzip()
239253

240254
return self.deployment.path
241255

@@ -310,6 +324,12 @@ def download(
310324
else:
311325
downloads = []
312326
for key, file in self._files_dictionary.items():
327+
if key == "deployment":
328+
# skip the download of the deployment directory
329+
# since identical files will be downloaded
330+
# in the deployment_tar
331+
_LOGGER.debug(f"Intentionally skipping downloading the file {key}")
332+
continue
313333
if file is not None:
314334
# save all the files to a temporary directory
315335
downloads.append(self._download(file, download_path))

src/sparsezoo/objects/directory.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,9 +298,22 @@ def unzip(self, extract_directory: Optional[str] = None, force: bool = False):
298298
member.name = os.path.basename(member.name)
299299
tar.extract(member=member, path=path)
300300
files.append(
301-
File(name=member.name, path=os.path.join(path, member.name))
301+
File(
302+
name=member.name,
303+
path=os.path.join(path, member.name),
304+
parent_directory=path,
305+
)
302306
)
303307
tar.close()
308+
# if path already exists, then the tar archive has already been unzipped
309+
# and we can just use the files in the directory
310+
elif os.path.exists(path):
311+
for file in os.listdir(path):
312+
files.append(
313+
File(
314+
name=file, path=os.path.join(path, file), parent_directory=path
315+
)
316+
)
304317

305318
self.name = name
306319
self.files = files

tests/sparsezoo/model/test_model.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@
2323
import pytest
2424

2525
from sparsezoo import Model
26+
from sparsezoo.objects.directories import SelectDirectory
2627

2728

2829
files_ic = {
2930
"training",
31+
"deployment.tar.gz",
3032
"deployment",
3133
"logs",
3234
"onnx",
@@ -182,6 +184,10 @@ def setup(self, stub, clone_sample_outputs, expected_files):
182184
temp_dir = tempfile.TemporaryDirectory(dir="/tmp")
183185
model = Model(stub, temp_dir.name)
184186
model.download()
187+
# since downloading the `deployment` file is
188+
# disabled by default, we need to do it
189+
# explicitly
190+
model.deployment.download()
185191
self._add_mock_files(temp_dir.name, clone_sample_outputs=clone_sample_outputs)
186192
model = Model(temp_dir.name)
187193

@@ -329,6 +335,56 @@ def test_model_gz_extraction_from_local_files(stub: str):
329335
shutil.rmtree(temp_dir.name)
330336

331337

338+
@pytest.mark.parametrize(
339+
"stub",
340+
[
341+
"zoo:cv/classification/mobilenet_v1-1.0/pytorch/sparseml/"
342+
"imagenet/pruned-moderate",
343+
],
344+
)
345+
def test_model_deployment_directory(stub):
346+
temp_dir = tempfile.TemporaryDirectory(dir="/tmp")
347+
expected_deployment_files = ["model.onnx"]
348+
349+
model = Model(stub, temp_dir.name)
350+
assert model.deployment_tar.is_archive
351+
# download and extract deployment tar
352+
deployment_dir_path = model.deployment_directory_path
353+
354+
# deployment and deployment_tar should point to the same files
355+
assert deployment_dir_path == model.deployment_tar.path == model.deployment.path
356+
# make sure that the model contains expected files
357+
assert set(os.listdir(temp_dir.name)) == {"deployment.tar.gz", "deployment"}
358+
assert (
359+
os.listdir(os.path.join(temp_dir.name, "deployment"))
360+
== expected_deployment_files
361+
)
362+
363+
assert isinstance(model.deployment, SelectDirectory)
364+
# TODO: this should be 1. However, for the `deployment` file type the API returns
365+
# both `model.onnx` and `deployment/model.onnx`.
366+
# This should probably be fixed on the API side
367+
assert (
368+
len(model.deployment.files) == 2
369+
) # should be == len(expected_deployment_files)
370+
371+
assert isinstance(model.deployment_tar, SelectDirectory)
372+
assert len(model.deployment_tar.files) == len(expected_deployment_files)
373+
assert not model.deployment_tar.is_archive
374+
375+
# test recreating the model from the local files
376+
model = Model(temp_dir.name)
377+
378+
assert isinstance(model.deployment, SelectDirectory)
379+
assert len(model.deployment.files) == len(expected_deployment_files)
380+
381+
assert isinstance(model.deployment_tar, SelectDirectory)
382+
assert len(model.deployment_tar.files) == len(expected_deployment_files)
383+
assert not model.deployment_tar.is_archive
384+
385+
shutil.rmtree(temp_dir.name)
386+
387+
332388
def _extraction_test_helper(model: Model):
333389
# download and extract model.onnx.tar.gz
334390
# path should point to extracted model.onnx file

0 commit comments

Comments
 (0)