Skip to content

Commit ff52598

Browse files
committed
initial commit
1 parent 8b2fca0 commit ff52598

File tree

6 files changed

+238
-13
lines changed

6 files changed

+238
-13
lines changed

src/sparseml/export/export.py

+21-6
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pathlib import Path
1717
from typing import Any, List, Optional, Union
1818

19-
from sparseml.export.helpers import apply_optimizations
19+
from sparseml.export.helpers import apply_optimizations, export_sample_inputs_outputs
2020
from sparseml.exporters import ExportTargets
2121
from sparseml.integration_helper_functions import (
2222
IntegrationHelperFunctions,
@@ -42,7 +42,7 @@ def export(
4242
single_graph_file: bool = True,
4343
graph_optimizations: Union[str, List[str], None] = "all",
4444
validate_correctness: bool = False,
45-
export_sample_inputs_outputs: bool = False,
45+
num_export_samples: int = 0,
4646
deployment_directory_name: str = "deployment",
4747
device: str = "auto",
4848
):
@@ -81,8 +81,8 @@ def export(
8181
to the exported model. Defaults to 'all'.
8282
:param validate_correctness: Whether to validate the correctness
8383
of the exported model. Defaults to False.
84-
:param export_sample_inputs_outputs: Whether to export sample
85-
inputs and outputs for the exported model.Defaults to False.
84+
:param num_export_samples: The number of samples to export for
85+
the exported model. Defaults to 0.
8686
:param deployment_directory_name: The name of the deployment
8787
directory to create for the exported model. Thus, the exported
8888
model will be saved to `target_path/deployment_directory_name`.
@@ -123,8 +123,23 @@ def export(
123123
single_graph_file=single_graph_file,
124124
)
125125

126-
if export_sample_inputs_outputs:
127-
helper_functions.export_sample_inputs_outputs(model, target_path)
126+
if num_export_samples:
127+
data_loader = auxiliary_items.get("validation_loader")
128+
if data_loader is None:
129+
raise ValueError(
130+
"To export sample inputs/outputs a data loader is needed."
131+
"To enable the export, provide a `validatation_loader` "
132+
"as a part of `auxiliary_items` output of the `create_model` function."
133+
)
134+
input_samples, output_samples = helper_functions.create_sample_inputs_outputs(
135+
num_samples=num_export_samples, data_loader=data_loader
136+
)
137+
export_sample_inputs_outputs(
138+
input_samples=input_samples,
139+
output_samples=output_samples,
140+
target_path=target_path,
141+
as_tar=True,
142+
)
128143

129144
deployment_path = helper_functions.create_deployment_folder(
130145
source_path, target_path, deployment_directory_name

src/sparseml/export/helpers.py

+67-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
import os
15+
import shutil
16+
import tarfile
1517
from collections import OrderedDict
1618
from enum import Enum
1719
from pathlib import Path
@@ -22,7 +24,7 @@
2224
from sparsezoo.utils.onnx import save_onnx
2325

2426

25-
__all__ = ["apply_optimizations"]
27+
__all__ = ["apply_optimizations", "export_sample_inputs_outputs"]
2628

2729

2830
class GraphOptimizationOptions(Enum):
@@ -34,6 +36,69 @@ class GraphOptimizationOptions(Enum):
3436
all = "all"
3537

3638

39+
class OutputsNames(Enum):
    """
    Names used when writing exported sample outputs to disk.
    """

    # name of the directory (and stem of the tar archive) holding the outputs
    basename = "sample-outputs"
    # prefix passed to the tensor exporter for every output file
    filename = "out"
42+
43+
44+
class InputsNames(Enum):
    """
    Names used when writing exported sample inputs to disk.
    """

    # name of the directory (and stem of the tar archive) holding the inputs
    basename = "sample-inputs"
    # prefix passed to the tensor exporter for every input file
    filename = "inp"
47+
48+
49+
def export_sample_inputs_outputs(
    input_samples: List["torch.Tensor"],  # noqa F821
    output_samples: List["torch.Tensor"],  # noqa F821
    target_path: Union[Path, str],
    as_tar: bool = False,
):
    """
    Write the sample inputs and outputs to disk below the target path.

    Both lists of tensors are moved to the CPU and then handed to
    `tensors_export`, each into its own sub-directory:
        <target_path>/sample-inputs/    (file name prefix "inp")
        <target_path>/sample-outputs/   (file name prefix "out")
    The exact file names inside those directories are decided by
    `tensors_export`.

    If as_tar is True, every sub-directory is afterwards packed into a
    gzipped tar archive placed next to it and the directory itself is
    removed, leaving only:
        <target_path>/sample-inputs.tar.gz
        <target_path>/sample-outputs.tar.gz

    :param input_samples: The input samples to save.
    :param output_samples: The output samples to save.
    :param target_path: The directory below which the samples are saved.
    :param as_tar: Whether to replace the sample directories
        with tar.gz archives. Defaults to False.
    """
    # function-scope import; presumably keeps pytorch optional when this
    # module is merely imported -- TODO confirm
    from sparseml.pytorch.utils.helpers import tensors_export, tensors_to_device

    cpu_inputs = tensors_to_device(input_samples, "cpu")
    cpu_outputs = tensors_to_device(output_samples, "cpu")

    # export both sets first, remembering where each one was written
    sample_dirs = []
    for naming, samples in ((InputsNames, cpu_inputs), (OutputsNames, cpu_outputs)):
        sample_dir = os.path.join(target_path, naming.basename.value)
        tensors_export(
            tensors=samples,
            export_dir=sample_dir,
            name_prefix=naming.filename.value,
        )
        sample_dirs.append(sample_dir)

    if not as_tar:
        return

    for sample_dir in sample_dirs:
        # the archive stores the directory under its own name, so unpacking
        # it recreates `sample-inputs/` or `sample-outputs/`
        with tarfile.open(sample_dir + ".tar.gz", "w:gz") as archive:
            archive.add(sample_dir, arcname=os.path.basename(sample_dir))
        shutil.rmtree(sample_dir)
100+
101+
37102
def apply_optimizations(
38103
onnx_file_path: Union[str, Path],
39104
available_optimizations: OrderedDict[str, Callable],

src/sparseml/integration_helper_functions.py

+35-4
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414

1515
from enum import Enum
1616
from pathlib import Path
17-
from typing import Any, Callable, Dict, Optional, Tuple, Union
17+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1818

1919
from pydantic import BaseModel, Field
20+
from tqdm import tqdm
2021

2122
from sparsezoo.utils.registry import RegistryMixin
2223

@@ -32,6 +33,27 @@ class Integrations(Enum):
3233
image_classification = "image-classification"
3334

3435

36+
def create_sample_inputs_outputs(
    data_loader: "torch.utils.data.DataLoader",  # noqa F821
    num_samples: int = 1,
) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]:  # noqa F821
    """
    Fetch the first `num_samples` batches from the data loader and
    return their inputs and outputs.

    Each batch yielded by the data loader must be indexable, with the
    inputs at index 0 and the outputs at index 1. One list entry is created
    per batch, so every entry is whatever the loader yields for a whole batch.

    :param data_loader: The data loader to get batches of inputs/outputs from.
    :param num_samples: The number of batches to fetch. Defaults to 1.
        If the data loader yields fewer batches, all of them are returned.
    :return: The inputs and outputs as two lists of equal length
    """
    from itertools import islice

    # A negative value never matched the former `batch_num == num_samples`
    # stop condition, which meant the whole loader was consumed; keep that
    # behaviour by not limiting the iteration in this case.
    limit = num_samples if num_samples >= 0 else None

    inputs, outputs = [], []
    # islice stops *before* requesting batch number `limit`, so no extra
    # batch is loaded just to be thrown away
    for data in tqdm(islice(data_loader, limit), total=limit):
        inputs.append(data[0])
        outputs.append(data[1])

    return inputs, outputs
55+
56+
3557
class IntegrationHelperFunctions(RegistryMixin, BaseModel):
3658
"""
3759
Registry that maps names to helper functions
@@ -74,10 +96,19 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
7496
graph_optimizations: Optional[Dict[str, Callable]] = Field(
7597
description="A mapping from names to graph optimization functions "
7698
)
77-
export_sample_inputs_outputs: Optional[Callable] = Field(
78-
description="A function that exports input/output samples given "
79-
"a (sparse) PyTorch model."
99+
100+
create_sample_inputs_outputs: Callable[
101+
Tuple["torch.utils.data.DataLoader", int], # noqa F821
102+
Tuple[List["torch.Tensor"], List["torch.Tensor"]], # noqa F821
103+
] = Field(
104+
default=create_sample_inputs_outputs,
105+
description="A function that takes: "
106+
" - a data loader "
107+
" - the number of samples to generate "
108+
"and returns: "
109+
" - the inputs and outputs as torch tensors ",
80110
)
111+
81112
create_deployment_folder: Optional[Callable] = Field(
82113
description="A function that creates a "
83114
"deployment folder for the exporter ONNX model"

tests/sparseml/export/test_helpers.py

+55-1
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,71 @@
1313
# limitations under the License.
1414

1515
import logging
16+
import os
17+
import tarfile
1618
from collections import OrderedDict
1719

1820
import onnx
1921
import pytest
2022

21-
from src.sparseml.export.helpers import apply_optimizations
23+
from src.sparseml.export.helpers import (
24+
apply_optimizations,
25+
export_sample_inputs_outputs,
26+
)
2227
from tests.sparseml.exporters.transforms.test_onnx_transform import (
2328
_create_model as create_dummy_onnx_file,
2429
)
2530

2631

32+
@pytest.mark.parametrize("as_tar", [True, False])
def test_export_sample_inputs_outputs(tmp_path, as_tar):
    """
    Export random samples and check the resulting directory layout,
    both for plain directories and for tar.gz archives.
    """
    pytest.importorskip("torch", reason="test requires pytorch")
    import torch

    batch_size, num_samples = 3, 5

    input_samples = [torch.randn(batch_size, 3, 224, 224) for _ in range(num_samples)]
    output_samples = [torch.randn(batch_size, 1000) for _ in range(num_samples)]

    export_sample_inputs_outputs(
        input_samples=input_samples,
        output_samples=output_samples,
        target_path=tmp_path,
        as_tar=as_tar,
    )

    prefix_by_folder = {"sample-inputs": "inp", "sample-outputs": "out"}
    folders = set(prefix_by_folder)
    archives = {f"{folder}.tar.gz" for folder in folders}

    expected_entries = folders
    if as_tar:
        # only the archives may be left behind by the export
        assert set(os.listdir(tmp_path)) == archives
        # unpack them so the contents can be inspected below
        for archive_name in archives:
            with tarfile.open(os.path.join(tmp_path, archive_name)) as tar:
                tar.extractall(path=tmp_path)
        expected_entries = archives | folders

    assert set(os.listdir(tmp_path)) == expected_entries

    # one zero-based, four-digit numbered .npz file per exported sample
    for folder in ("sample-inputs", "sample-outputs"):
        prefix = prefix_by_folder[folder]
        expected_files = {f"{prefix}-{idx:04d}.npz" for idx in range(num_samples)}
        assert set(os.listdir(os.path.join(tmp_path, folder))) == expected_files
79+
80+
2781
def foo(onnx_model):
2882
logging.debug("foo")
2983
return onnx_model

tests/sparseml/pytorch/image_classification/test_integration_helper_functions.py

+1
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,4 @@ def test_integration_helper_functions():
2828
assert image_classification.create_dummy_input
2929
assert image_classification.export
3030
assert image_classification.graph_optimizations is None
31+
assert image_classification.create_sample_inputs_outputs
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
from sparseml.integration_helper_functions import create_sample_inputs_outputs
18+
19+
20+
@pytest.mark.parametrize("num_samples", [0, 1, 5])
def test_create_sample_inputs_outputs(num_samples):
    """
    Check that the requested number of batches is returned and that
    every returned batch keeps the loader's batch dimension of 1.
    """
    pytest.importorskip("torch", reason="test requires pytorch")
    import torch
    from torch.utils.data import DataLoader, Dataset

    class PairedTensorDataset(Dataset):
        # minimal dataset returning (input, target) pairs by index
        def __init__(self, features, targets):
            self.features = features
            self.targets = targets

        def __len__(self):
            return len(self.features)

        def __getitem__(self, index):
            return self.features[index], self.targets[index]

    all_inputs = torch.randn((100, 3, 224, 224))
    all_outputs = torch.randint(0, 10, (100, 50))

    data_loader = DataLoader(
        PairedTensorDataset(all_inputs, all_outputs), batch_size=1
    )

    inputs, outputs = create_sample_inputs_outputs(data_loader, num_samples)

    assert all(
        tuple(sample_input.shape) == (1, 3, 224, 224) for sample_input in inputs
    )
    assert all(tuple(sample_output.shape) == (1, 50) for sample_output in outputs)

    assert len(inputs) == num_samples == len(outputs)

0 commit comments

Comments
 (0)