|
26 | 26 | import click
|
27 | 27 | from sparseml.pytorch.models.registry import ModelRegistry
|
28 | 28 | from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
|
| 29 | +from sparseml.pytorch.optim.manager import ScheduledModifierManager |
29 | 30 | from sparseml.pytorch.torchvision import presets
|
30 | 31 | from sparseml.pytorch.utils import ModuleExporter
|
31 | 32 | from sparseml.pytorch.utils.model import load_model
|
|
60 | 61 | help="The root dir path where the dataset is stored or should "
|
61 | 62 | "be downloaded to if available",
|
62 | 63 | )
|
| 64 | +@click.option( |
| 65 | + "--one-shot", |
| 66 | + default=None, |
| 67 | + type=str, |
| 68 | + help="Path to recipe to use to apply in a one-shot manner", |
| 69 | +) |
63 | 70 | @click.option(
|
64 | 71 | "--labels-to-class-mapping",
|
65 | 72 | type=click.Path(dir_okay=False, file_okay=True, exists=True, path_type=Path),
|
@@ -118,6 +125,7 @@ def main(
|
118 | 125 | arch_key: str,
|
119 | 126 | checkpoint_path: str,
|
120 | 127 | dataset_path: Path,
|
| 128 | + one_shot: Optional[str], |
121 | 129 | labels_to_class_mapping: Optional[Path],
|
122 | 130 | num_samples: int,
|
123 | 131 | onnx_opset: int,
|
@@ -159,6 +167,9 @@ def main(
|
159 | 167 |
|
160 | 168 | load_model(checkpoint_path, model, strict=True)
|
161 | 169 |
|
| 170 | + if one_shot is not None: |
| 171 | + ScheduledModifierManager.from_yaml(one_shot).apply(model) |
| 172 | + |
162 | 173 | if labels_to_class_mapping is not None:
|
163 | 174 | with open(labels_to_class_mapping) as fp:
|
164 | 175 | labels_to_class_mapping = json.load(fp)
|
|
0 commit comments