11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
-
14
+ import os
15
+ import shutil
16
+ import tarfile
15
17
from collections import OrderedDict
16
18
from enum import Enum
17
19
from pathlib import Path
22
24
from sparsezoo .utils .onnx import save_onnx
23
25
24
26
25
- __all__ = ["apply_optimizations" ]
27
+ __all__ = ["apply_optimizations" , "export_sample_inputs_outputs" ]
26
28
27
29
28
30
class GraphOptimizationOptions (Enum ):
@@ -34,6 +36,69 @@ class GraphOptimizationOptions(Enum):
34
36
all = "all"
35
37
36
38
39
+ class OutputsNames (Enum ):
40
+ basename = "sample-outputs"
41
+ filename = "out"
42
+
43
+
44
+ class InputsNames (Enum ):
45
+ basename = "sample-inputs"
46
+ filename = "inp"
47
+
48
+
49
+ def export_sample_inputs_outputs (
50
+ input_samples : List ["torch.Tensor" ], # noqa F821
51
+ output_samples : List ["torch.Tensor" ], # noqa F821
52
+ target_path : Union [Path , str ],
53
+ as_tar : bool = False ,
54
+ ):
55
+ """
56
+ Save the input and output samples to the target path.
57
+
58
+ Input samples will be saved to:
59
+ .../sample-inputs/inp_0001.npz
60
+ .../sample-inputs/inp_0002.npz
61
+ ...
62
+
63
+ Output samples will be saved to:
64
+ .../sample-outputs/out_0001.npz
65
+ .../sample-outputs/out_0002.npz
66
+ ...
67
+
68
+ If as_tar is True, the samples will be saved as tar files:
69
+ .../sample-inputs.tar.gz
70
+ .../sample-outputs.tar.gz
71
+
72
+ :param input_samples: The input samples to save.
73
+ :param output_samples: The output samples to save.
74
+ :param target_path: The path to save the samples to.
75
+ :param as_tar: Whether to save the samples as tar files.
76
+ """
77
+
78
+ from sparseml .pytorch .utils .helpers import tensors_export , tensors_to_device
79
+
80
+ input_samples = tensors_to_device (input_samples , "cpu" )
81
+ output_samples = tensors_to_device (output_samples , "cpu" )
82
+
83
+ for tensors , names in zip (
84
+ [input_samples , output_samples ], [InputsNames , OutputsNames ]
85
+ ):
86
+ tensors_export (
87
+ tensors = tensors ,
88
+ export_dir = os .path .join (target_path , names .basename .value ),
89
+ name_prefix = names .filename .value ,
90
+ )
91
+ if as_tar :
92
+ for folder_name_to_tar in [
93
+ InputsNames .basename .value ,
94
+ OutputsNames .basename .value ,
95
+ ]:
96
+ folder_path = os .path .join (target_path , folder_name_to_tar )
97
+ with tarfile .open (folder_path + ".tar.gz" , "w:gz" ) as tar :
98
+ tar .add (folder_path , arcname = os .path .basename (folder_path ))
99
+ shutil .rmtree (folder_path )
100
+
101
+
37
102
def apply_optimizations (
38
103
onnx_file_path : Union [str , Path ],
39
104
available_optimizations : OrderedDict [str , Callable ],
0 commit comments