Skip to content

Commit f1ace45

Browse files
corey-nmrahul-tuli
authored and
Benjamin
committed
Adding --one-shot argument to torchvision export (#1300)
Co-authored-by: Rahul Tuli <[email protected]>
1 parent ba4184b commit f1ace45

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

src/sparseml/pytorch/torchvision/export_onnx.py

+11
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import click
2727
from sparseml.pytorch.models.registry import ModelRegistry
2828
from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
29+
from sparseml.pytorch.optim.manager import ScheduledModifierManager
2930
from sparseml.pytorch.torchvision import presets
3031
from sparseml.pytorch.utils import ModuleExporter
3132
from sparseml.pytorch.utils.model import load_model
@@ -60,6 +61,12 @@
6061
help="The root dir path where the dataset is stored or should "
6162
"be downloaded to if available",
6263
)
64+
@click.option(
65+
"--one-shot",
66+
default=None,
67+
type=str,
68+
help="Path to recipe to use to apply in a one-shot manner",
69+
)
6370
@click.option(
6471
"--labels-to-class-mapping",
6572
type=click.Path(dir_okay=False, file_okay=True, exists=True, path_type=Path),
@@ -118,6 +125,7 @@ def main(
118125
arch_key: str,
119126
checkpoint_path: str,
120127
dataset_path: Path,
128+
one_shot: Optional[str],
121129
labels_to_class_mapping: Optional[Path],
122130
num_samples: int,
123131
onnx_opset: int,
@@ -159,6 +167,9 @@ def main(
159167

160168
load_model(checkpoint_path, model, strict=True)
161169

170+
if one_shot is not None:
171+
ScheduledModifierManager.from_yaml(one_shot).apply(model)
172+
162173
if labels_to_class_mapping is not None:
163174
with open(labels_to_class_mapping) as fp:
164175
labels_to_class_mapping = json.load(fp)

0 commit comments

Comments
 (0)