
Commit c2dfe65

bfineran authored and Benjamin committed
[QuantizationModifier] refactor base - move deprecated code to legacy file, add object routing for yaml load (#1059)
* move existing ModifierQuantization and tests to legacy file
* [QuantizationModifier] refactor base - move deprecated code to legacy file, add object routing for yaml load
1 parent 6cf2b08 commit c2dfe65

File tree: 6 files changed, +934 -809 lines


src/sparseml/optim/modifier.py (+18 -4)

```diff
@@ -20,7 +20,7 @@
 import hashlib
 import re
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Union
+from typing import Any, Callable, Dict, List, Type, Union
 
 import yaml
 from yaml import ScalarNode
@@ -822,13 +822,22 @@ class ModifierYAML(object):
 
     :param framework: the string representing the ML framework the modifier should
         be stored under
+    :param swap_class_by_state_fn: optional function to provide a different class
+        to construct on yaml load based on the state given (ie provide a
+        legacy class to load if certain parameters are passed). Expected format
+        is to take a dict of kwargs, expects a class to be returned
     """
 
-    def __init__(self, framework: str):
+    def __init__(
+        self,
+        framework: str,
+        swap_class_by_state_fn: Callable[[Dict[str, Any]], Type[BaseModifier]] = None,
+    ):
         if not framework:
             raise ValueError("a framework is required")
 
         self._framework = framework
+        self._swap_class_by_state_fn = swap_class_by_state_fn
 
     def __call__(self, clazz):
         """
@@ -838,13 +847,18 @@ def __call__(self, clazz):
         yaml_key = "{}".format(BaseModifier.yaml_key(clazz, self._framework))
 
         def constructor(loader, node):
-            instance = clazz.__new__(clazz)
-            yield instance
             state = (
                 loader.construct_mapping(node, deep=True)
                 if not isinstance(node, ScalarNode)
                 else {}
             )
+            target_class = (
+                self._swap_class_by_state_fn(state)
+                if self._swap_class_by_state_fn is not None
+                else clazz
+            )
+            instance = target_class.__new__(target_class)
+            yield instance
             # ignore the log_types arg in recipes to maintain backwards compatability
             # while recipes are updated
             if "log_types" in state:
```
src/sparseml/pytorch/sparsification/modifier.py (+14 -3)

```diff
@@ -20,7 +20,7 @@
 """
 
 import math
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
 
 from torch import Tensor
 from torch.nn import Module
@@ -52,10 +52,21 @@ class PyTorchModifierYAML(ModifierYAML):
     """
     A decorator to handle making a pytorch modifier class YAML ready.
     IE it can be loaded in through the yaml plugin easily.
+
+    :param swap_class_by_state_fn: optional function to provide a different class
+        to construct on yaml load based on the state given (ie provide a
+        legacy class to load if certain parameters are passed). Expected format
+        is to take a dict of kwargs, expects a class to be returned
     """
 
-    def __init__(self):
-        super().__init__(PYTORCH_FRAMEWORK)
+    def __init__(
+        self,
+        swap_class_by_state_fn: Callable[[Dict[str, Any]], Type[BaseModifier]] = None,
+    ):
+        super().__init__(
+            PYTORCH_FRAMEWORK,
+            swap_class_by_state_fn=swap_class_by_state_fn,
+        )
 
 
 class Modifier(BaseModifier):
```
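A modifier opts in by passing the routing function to the decorator. A hypothetical wiring sketch, assuming `PyTorchModifierYAML` and `ScheduledModifier` are importable from the module shown above; the legacy class and the trigger kwargs (`submodules`, `model_fuse_fn_name`) are illustrative assumptions, not the code this PR adds:

```python
from sparseml.pytorch.sparsification.modifier import (
    PyTorchModifierYAML,
    ScheduledModifier,
)


class LegacyQuantizationModifier(ScheduledModifier):
    """Hypothetical stand-in for the deprecated class kept in the legacy file."""


def _route_quantization_class(state):
    # hypothetical kwargs that only the deprecated implementation accepts
    legacy_only = {"submodules", "model_fuse_fn_name"}
    if legacy_only & state.keys():
        return LegacyQuantizationModifier
    return QuantizationModifier


@PyTorchModifierYAML(swap_class_by_state_fn=_route_quantization_class)
class QuantizationModifier(ScheduledModifier):
    """Hypothetical stand-in for the refactored modifier."""
```

Routing at load time keeps old recipes constructing the deprecated class unchanged, while recipes without the legacy-only kwargs get the refactored implementation.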
