[QuantizationModifier] override params for model fuse step #1209

Merged
@@ -71,7 +71,7 @@ def _select_quantization_modifier(state: Dict[str, Any]) -> Type:
class QuantizationModifier(ScheduledModifier):
"""
Enables quantization aware training (QAT) for a given module or its submodules
-After the start epoch, the specified module(s)' forward pass will emulate
+After the start epoch, the specified module(s) forward pass will emulate
quantized execution and the modifier will be enabled until training is completed.

| Sample yaml:
@@ -99,6 +99,7 @@ class QuantizationModifier(ScheduledModifier):
| exclude_module_types: ["ReLU"]
| disable_quantization_observer_epoch: 2.0
| freeze_bn_stats_epoch: 3.0
| model_fuse_fn_name: 'fuse_module'

:param start_epoch: The epoch to start the modifier at
:param default_scheme: Default QuantizationScheme to use when enabling quantization
@@ -123,6 +124,12 @@ class QuantizationModifier(ScheduledModifier):
not be updated. Leave None to not disable observers during QAT. Default is None
:param freeze_bn_stats_epoch: Epoch to stop the tracking of batch norm stats. Leave
None to not stop tracking batch norm stats during QAT. Default is None
:param model_fuse_fn_name: Name of model function to fuse the model in place prior
to performing QAT. Set as 'no_fuse' to skip module fusing. Set as None or
'conv_bn_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`.
Default is None
:param model_fuse_fn_kwargs: Dictionary of keyword argument values to be passed
to the model fusing function. Default is None
:param num_calibration_steps: Number of steps to run post training calibration for.
When None, the entire calibration_dataloader is used
:param end_epoch: Disabled, setting to anything other than -1 will raise an
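As a quick illustration, a minimal sketch of constructing the modifier with the new arguments; the import path and epoch value are assumptions for the example, not part of this diff:

from sparseml.pytorch.sparsification.quantization import QuantizationModifier

# hypothetical usage: defer fusing to the model's own `fuse_module` method,
# forwarding inplace=True to it when QAT is enabled at the start epoch
modifier = QuantizationModifier(
    start_epoch=0.0,
    model_fuse_fn_name="fuse_module",
    model_fuse_fn_kwargs={"inplace": True},
)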
@@ -138,6 +145,8 @@ def __init__(
exclude_module_types: Optional[List[str]] = None,
disable_quantization_observer_epoch: Optional[float] = None,
freeze_bn_stats_epoch: Optional[float] = None,
model_fuse_fn_name: Optional[str] = None,
model_fuse_fn_kwargs: Optional[Dict[str, Any]] = None,
num_calibration_steps: Optional[int] = None,
end_epoch: float = -1.0,
):
@@ -164,10 +173,20 @@ def __init__(
self._calibration_dataloader = None
self._calibration_function = None

self._model_fuse_fn_name = model_fuse_fn_name
self._model_fuse_fn_kwargs = model_fuse_fn_kwargs or {}
if (
isinstance(self._model_fuse_fn_name, str)
and self._model_fuse_fn_name.lower() == "none"
):
self._model_fuse_fn_name = None

self._qat_enabled = False
self._quantization_observer_disabled = False
self._bn_stats_frozen = False

self._validate_params()

@BaseModifier.sparsification_types.getter
def sparsification_types(self) -> List[SparsificationTypes]:
"""
@@ -311,6 +330,40 @@ def num_calibration_steps(self, value: Optional[int]):
"""
self._num_calibration_steps = value

@ModifierProp()
def model_fuse_fn_name(self) -> Optional[str]:
"""
:return: Name of model function to fuse the model in place prior
to performing QAT. None selects the default: 'no_fuse' when any
quantization scheme targets 'tensorrt' hardware, otherwise
`sparseml.pytorch.utils.fuse_module_conv_bn_relus`.
"""
return self._model_fuse_fn_name

@model_fuse_fn_name.setter
def model_fuse_fn_name(self, value: Optional[str]):
"""
:param value: Name of model function to fuse the model in place prior
to performing QAT. Set None to use the default function
`sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Set as 'no_fuse'
to skip module fusing.
"""
self._model_fuse_fn_name = value
if (
isinstance(self._model_fuse_fn_name, str)
and self._model_fuse_fn_name.lower() == "none"
):
self._model_fuse_fn_name = None
self._validate_params()

@ModifierProp()
def model_fuse_fn_kwargs(self) -> Dict[str, Any]:
"""
:return: Dictionary of keyword arguments to be passed to the
model fuse function
"""
return self._model_fuse_fn_kwargs

def initialize(
self,
module: Module,
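Because the setter reapplies the same normalization and re-runs `_validate_params`, the fuse behavior can also be overridden on an existing modifier; continuing the sketch above:

modifier.model_fuse_fn_name = "no_fuse"  # skip module fusing entirely
modifier.model_fuse_fn_name = "None"     # normalized back to None (default fusing)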
@@ -434,7 +487,7 @@ def _freeze_bn_stats_update_ready(self, epoch: float) -> bool:

def _enable_module_qat(self, module: Module):
# fuse conv-bn-relu blocks prior to quantization emulation
-fuse_module_conv_bn_relus(module, inplace=True)
+self._fuse(module)

# add quantization_schemes to target submodules
set_quantization_schemes(
@@ -458,7 +511,22 @@ def _enable_module_qat(self, module):

self._calibrate_if_possible(module)

-def _calibrate_if_possible(self, module):
def _fuse(self, module: Module):
if self.model_fuse_fn_name in [None, "conv_bn_relus"]:
self._model_fuse_fn_kwargs["inplace"] = True
fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs)
elif self.model_fuse_fn_name != "no_fuse":
module_fuse_fn = getattr(module, self._model_fuse_fn_name, None)
if module_fuse_fn is None or not callable(module_fuse_fn):
raise ValueError(
"Invalid model_fuse_fn_name. "
"Module has no callable function {}".format(
self._model_fuse_fn_name
)
)
module_fuse_fn(**self._model_fuse_fn_kwargs)

+def _calibrate_if_possible(self, module: Module):
if self.num_calibration_steps == 0 and self._calibration_dataloader:
warnings.warn(
f"num_calibration_steps is {self.num_calibration_steps}."
Expand All @@ -477,7 +545,7 @@ def _calibrate_if_possible(self, module):
elif self._calibration_dataloader:
self._calibrate(module)

-def _calibrate(self, module):
+def _calibrate(self, module: Module):
_LOGGER.info("Running quantization calibration using calibration_dataloader")

module_training = module.training
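To illustrate the custom-function path in `_fuse`, here is a toy model exposing its own `fuse_module` method; the model, layer names, and fusion list are hypothetical, not from this PR. With model_fuse_fn_name="fuse_module", `_fuse` resolves the method via getattr and calls it with model_fuse_fn_kwargs:

import torch

class FusableModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.bn = torch.nn.BatchNorm2d(8)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

    def fuse_module(self, inplace: bool = True):
        # fuse the conv-bn-relu sequence with torch's built-in utility;
        # newer torch versions offer fuse_modules_qat as the training-mode variant
        torch.quantization.fuse_modules(
            self, [["conv", "bn", "relu"]], inplace=inplace
        )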
@@ -530,6 +598,20 @@ def _validate_params(self):
)
)

all_schemes = [self._default_scheme]
if self._submodule_schemes:
all_schemes += list(self._submodule_schemes.values())
if self._module_type_schemes:
all_schemes += list(self._module_type_schemes.values())
if any(scheme.target_hardware == "tensorrt" for scheme in all_schemes) and (
self._model_fuse_fn_name != "no_fuse"
):
_LOGGER.info(
"QuantizationModifier - target hardware tensorrt detected - "
"Disabling model fuse step"
)
self._model_fuse_fn_name = "no_fuse"


class _QuantizationSchemesDict(dict):
# wrapper class for dict to override the __str__ method for yaml serialization
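The end-to-end effect of that validation, sketched; the import path is assumed from the package layout:

from sparseml.pytorch.sparsification.quantization import (
    QuantizationModifier,
    QuantizationScheme,
)

# any scheme targeting tensorrt hardware disables the fuse step automatically
modifier = QuantizationModifier(
    start_epoch=0.0,
    default_scheme=QuantizationScheme.tensorrt(),
)
assert modifier.model_fuse_fn_name == "no_fuse"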
9 changes: 9 additions & 0 deletions src/sparseml/pytorch/sparsification/quantization/quantize.py
@@ -161,6 +161,13 @@ def __init__(self, *args, **kwargs):
"not quantize output activations. Default is None"
),
)
target_hardware: Optional[str] = Field(
default=None,
description=(
"target deployment runtime/hardware name to be set by default "
"classmethods. Default is None"
),
)

@classmethod
def load(
@@ -212,6 +219,7 @@ def deepsparse(cls) -> "QuantizationScheme":
input_activations=QuantizationArgs(num_bits=8, symmetric=False),
weights=QuantizationArgs(num_bits=8, symmetric=True),
output_activations=None,
target_hardware="deepsparse",
)

@classmethod
@@ -225,6 +233,7 @@ def tensorrt(cls) -> "QuantizationScheme":
input_activations=QuantizationArgs(num_bits=8, symmetric=True),
weights=QuantizationArgs(num_bits=8, symmetric=True),
output_activations=None,
target_hardware="tensorrt",
)

def get_qconfig(self) -> "torch.quantization.QConfig":
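In short, the default-scheme classmethods now tag schemes with their runtime while a plain scheme stays untagged; a sketch of the expected behavior, import path assumed:

from sparseml.pytorch.sparsification.quantization import QuantizationScheme

assert QuantizationScheme.deepsparse().target_hardware == "deepsparse"
assert QuantizationScheme.tensorrt().target_hardware == "tensorrt"
assert QuantizationScheme().target_hardware is None  # default field value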
@@ -300,6 +300,8 @@ def test_quantization_modifier_yaml():
disable_quantization_observer_epoch = 2.0
freeze_bn_stats_epoch = 3.0
num_calibration_steps = 1000
model_fuse_fn_name = "custom_fuse_fn"
model_fuse_fn_kwargs = dict(inplace=True)

yaml_str = f"""
!QuantizationModifier
@@ -311,6 +313,8 @@
disable_quantization_observer_epoch: {disable_quantization_observer_epoch}
freeze_bn_stats_epoch: {freeze_bn_stats_epoch}
num_calibration_steps: {num_calibration_steps}
model_fuse_fn_name: {model_fuse_fn_name}
model_fuse_fn_kwargs: {model_fuse_fn_kwargs}
"""
yaml_modifier = QuantizationModifier.load_obj(
yaml_str
@@ -327,6 +331,8 @@
disable_quantization_observer_epoch=disable_quantization_observer_epoch,
freeze_bn_stats_epoch=freeze_bn_stats_epoch,
num_calibration_steps=num_calibration_steps,
model_fuse_fn_name=model_fuse_fn_name,
model_fuse_fn_kwargs=model_fuse_fn_kwargs,
)

assert isinstance(yaml_modifier, QuantizationModifier)
@@ -373,3 +379,13 @@ def test_quantization_modifier_yaml():
== serialized_modifier.num_calibration_steps
== obj_modifier.num_calibration_steps
)
assert (
yaml_modifier.model_fuse_fn_name
== serialized_modifier.model_fuse_fn_name
== obj_modifier.model_fuse_fn_name
)
assert (
yaml_modifier.model_fuse_fn_kwargs
== serialized_modifier.model_fuse_fn_kwargs
== obj_modifier.model_fuse_fn_kwargs
)
@@ -119,6 +119,7 @@ def test_quantization_args_get_observer(
input_activations=QuantizationArgs(num_bits=8, symmetric=False),
weights=QuantizationArgs(num_bits=8, symmetric=True),
output_activations=None,
target_hardware="deepsparse",
),
),
(
@@ -127,10 +128,11 @@
input_activations=QuantizationArgs(num_bits=8, symmetric=True),
weights=QuantizationArgs(num_bits=8, symmetric=True),
output_activations=None,
target_hardware="tensorrt",
),
),
# adding to raise an issue if default scheme changes from deepsparse
("deepsparse", QuantizationScheme()),
("deepsparse", QuantizationScheme(target_hardware="deepsparse")),
],
)
def test_load_quantization_scheme_from_str(scheme_str, expected_scheme):