Skip to content

Commit 25808e7

Browse files
author
Benjamin
committed
[QuantizationModifier] override params for model fuse step
1 parent 0c4786a commit 25808e7

File tree

4 files changed

+114
-5
lines changed

4 files changed

+114
-5
lines changed

src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py

+86-4
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def _select_quantization_modifier(state: Dict[str, Any]) -> Type:
7171
class QuantizationModifier(ScheduledModifier):
7272
"""
7373
Enables quantization aware training (QAT) for a given module or its submodules
74-
After the start epoch, the specified module(s)' forward pass will emulate
74+
After the start epoch, the specified module(s)' forward pass will emulate
7575
quantized execution and the modifier will be enabled until training is completed.
7676
7777
| Sample yaml:
@@ -99,6 +99,7 @@ class QuantizationModifier(ScheduledModifier):
9999
| exclude_module_types: ["ReLU"]
100100
| disable_quantization_observer_epoch: 2.0
101101
| freeze_bn_stats_epoch: 3.0
102+
| model_fuse_fn_name: 'fuse_module'
102103
103104
:param start_epoch: The epoch to start the modifier at
104105
:param default_scheme: Default QuantizationScheme to use when enabling quantization
@@ -123,6 +124,12 @@ class QuantizationModifier(ScheduledModifier):
123124
not be updated. Leave None to not disable observers during QAT. Default is None
124125
:param freeze_bn_stats_epoch: Epoch to stop the tracking of batch norm stats. Leave
125126
None to not stop tracking batch norm stats during QAT. Default is None
127+
:param model_fuse_fn_name: Name of model function to fuse the model in place prior
128+
to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as
129+
'conv_bn_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`.
130+
Default is None
131+
:param model_fuse_fn_kwargs: dictionary of keyword argument values to be passed
132+
to the model fusing function
126133
:param num_calibration_steps: Number of steps to run post training calibration for.
127134
When None, the entire calibration_dataloader is used
128135
:param end_epoch: Disabled, setting to anything other than -1 will raise an
@@ -138,6 +145,8 @@ def __init__(
138145
exclude_module_types: Optional[List[str]] = None,
139146
disable_quantization_observer_epoch: Optional[float] = None,
140147
freeze_bn_stats_epoch: Optional[float] = None,
148+
model_fuse_fn_name: Optional[str] = None,
149+
model_fuse_fn_kwargs: Optional[Dict[str, Any]] = None,
141150
num_calibration_steps: Optional[int] = None,
142151
end_epoch: float = -1.0,
143152
):
@@ -164,10 +173,20 @@ def __init__(
164173
self._calibration_dataloader = None
165174
self._calibration_function = None
166175

176+
self._model_fuse_fn_name = model_fuse_fn_name
177+
self._model_fuse_fn_kwargs = model_fuse_fn_kwargs or {}
178+
if (
179+
isinstance(self._model_fuse_fn_name, str)
180+
and self._model_fuse_fn_name.lower() == "none"
181+
):
182+
self._model_fuse_fn_name = None
183+
167184
self._qat_enabled = False
168185
self._quantization_observer_disabled = False
169186
self._bn_stats_frozen = False
170187

188+
self._validate_params()
189+
171190
@BaseModifier.sparsification_types.getter
172191
def sparsification_types(self) -> List[SparsificationTypes]:
173192
"""
@@ -311,6 +330,40 @@ def num_calibration_steps(self, value: Optional[int]):
311330
"""
312331
self._num_calibration_steps = value
313332

333+
@ModifierProp()
334+
def model_fuse_fn_name(self) -> Optional[str]:
335+
"""
336+
:return: Name of model function to fuse the model in place prior
337+
to performing QAT. None sets to default function.
338+
If any quantization scheme targets 'tensorrt' hardware, this is forced to 'no_fuse', otherwise
339+
`sparseml.pytorch.utils.fuse_module_conv_bn_relus`.
340+
"""
341+
return self._model_fuse_fn_name
342+
343+
@model_fuse_fn_name.setter
344+
def model_fuse_fn_name(self, value: Optional[str]):
345+
"""
346+
:params value: Name of model function to fuse the model in place prior
347+
to performing QAT. Set None to use the default function
348+
`sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Set as 'no_fuse'
349+
to skip module fusing.
350+
"""
351+
self._model_fuse_fn_name = value
352+
if (
353+
isinstance(self._model_fuse_fn_name, str)
354+
and self._model_fuse_fn_name.lower() == "none"
355+
):
356+
self._model_fuse_fn_name = None
357+
self._validate_params()
358+
359+
@ModifierProp()
360+
def model_fuse_fn_kwargs(self) -> Dict[str, Any]:
361+
"""
362+
:return: Dictionary of keyword arguments to be passed to the
363+
model fuse function
364+
"""
365+
return self._model_fuse_fn_kwargs
366+
314367
def initialize(
315368
self,
316369
module: Module,
@@ -434,7 +487,7 @@ def _freeze_bn_stats_update_ready(self, epoch: float) -> bool:
434487

435488
def _enable_module_qat(self, module: Module):
436489
# fuse conv-bn-relu blocks prior to quantization emulation
437-
fuse_module_conv_bn_relus(module, inplace=True)
490+
self._fuse(module)
438491

439492
# add quantization_schemes to target submodules
440493
set_quantization_schemes(
@@ -458,7 +511,22 @@ def _enable_module_qat(self, module: Module):
458511

459512
self._calibrate_if_possible(module)
460513

461-
def _calibrate_if_possible(self, module):
514+
def _fuse(self, module: Module):
515+
if self.model_fuse_fn_name in [None, "conv_bn_relus"]:
516+
self._model_fuse_fn_kwargs["inplace"] = True
517+
fuse_module_conv_bn_relus(module, **self._model_fuse_fn_kwargs)
518+
elif self.model_fuse_fn_name != "no_fuse":
519+
module_fuse_fn = getattr(module, self._model_fuse_fn_name, None)
520+
if module_fuse_fn is None or not callable(module_fuse_fn):
521+
raise ValueError(
522+
"Invalid model_fuse_fn_name. "
523+
"Module has no callable function {}".format(
524+
self._model_fuse_fn_name
525+
)
526+
)
527+
module_fuse_fn(**self._model_fuse_fn_kwargs)
528+
529+
def _calibrate_if_possible(self, module: Module):
462530
if self.num_calibration_steps == 0 and self._calibration_dataloader:
463531
warnings.warn(
464532
f"num_calibration_steps is {self.num_calibration_steps}."
@@ -477,7 +545,7 @@ def _calibrate_if_possible(self, module):
477545
elif self._calibration_dataloader:
478546
self._calibrate(module)
479547

480-
def _calibrate(self, module):
548+
def _calibrate(self, module: Module):
481549
_LOGGER.info("Running quantization calibration using calibration_dataloader")
482550

483551
module_training = module.training
@@ -530,6 +598,20 @@ def _validate_params(self):
530598
)
531599
)
532600

601+
all_schemes = [self._default_scheme]
602+
if self._submodule_schemes:
603+
all_schemes += list(self._submodule_schemes.values())
604+
if self._module_type_schemes:
605+
all_schemes += list(self._module_type_schemes.values())
606+
if any(scheme.target_hardware == "tensorrt" for scheme in all_schemes) and (
607+
self._model_fuse_fn_name != "no_fuse"
608+
):
609+
_LOGGER.info(
610+
"QuantizationModifier - target hardware tensorrt detected - "
611+
"Disabling model fuse step"
612+
)
613+
self._model_fuse_fn_name = "no_fuse"
614+
533615

534616
class _QuantizationSchemesDict(dict):
535617
# wrapper class for dict to override the __str__ method for yaml serialization

src/sparseml/pytorch/sparsification/quantization/quantize.py

+9
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,13 @@ def __init__(self, *args, **kwargs):
161161
"not quantize output activations. Default is None"
162162
),
163163
)
164+
target_hardware: Optional[str] = Field(
165+
default=None,
166+
description=(
167+
"target deployment runtime/hardware name to be set by default "
168+
"classmethods. Default is None"
169+
),
170+
)
164171

165172
@classmethod
166173
def load(
@@ -212,6 +219,7 @@ def deepsparse(cls) -> "QuantizationScheme":
212219
input_activations=QuantizationArgs(num_bits=8, symmetric=False),
213220
weights=QuantizationArgs(num_bits=8, symmetric=True),
214221
output_activations=None,
222+
target_hardware="deepsparse",
215223
)
216224

217225
@classmethod
@@ -225,6 +233,7 @@ def tensorrt(cls) -> "QuantizationScheme":
225233
input_activations=QuantizationArgs(num_bits=8, symmetric=True),
226234
weights=QuantizationArgs(num_bits=8, symmetric=True),
227235
output_activations=None,
236+
target_hardware="tensorrt",
228237
)
229238

230239
def get_qconfig(self) -> "torch.quantization.QConfig":

tests/sparseml/pytorch/sparsification/quantization/test_modifier_quantization.py

+16
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,8 @@ def test_quantization_modifier_yaml():
300300
disable_quantization_observer_epoch = 2.0
301301
freeze_bn_stats_epoch = 3.0
302302
num_calibration_steps = 1000
303+
model_fuse_fn_name = "custom_fuse_fn"
304+
model_fuse_fn_kwargs = dict(inplace=True)
303305

304306
yaml_str = f"""
305307
!QuantizationModifier
@@ -311,6 +313,8 @@ def test_quantization_modifier_yaml():
311313
disable_quantization_observer_epoch: {disable_quantization_observer_epoch}
312314
freeze_bn_stats_epoch: {freeze_bn_stats_epoch}
313315
num_calibration_steps: {num_calibration_steps}
316+
model_fuse_fn_name: {model_fuse_fn_name}
317+
model_fuse_fn_kwargs: {model_fuse_fn_kwargs}
314318
"""
315319
yaml_modifier = QuantizationModifier.load_obj(
316320
yaml_str
@@ -327,6 +331,8 @@ def test_quantization_modifier_yaml():
327331
disable_quantization_observer_epoch=disable_quantization_observer_epoch,
328332
freeze_bn_stats_epoch=freeze_bn_stats_epoch,
329333
num_calibration_steps=num_calibration_steps,
334+
model_fuse_fn_name=model_fuse_fn_name,
335+
model_fuse_fn_kwargs=model_fuse_fn_kwargs,
330336
)
331337

332338
assert isinstance(yaml_modifier, QuantizationModifier)
@@ -373,3 +379,13 @@ def test_quantization_modifier_yaml():
373379
== serialized_modifier.num_calibration_steps
374380
== obj_modifier.num_calibration_steps
375381
)
382+
assert (
383+
yaml_modifier.model_fuse_fn_name
384+
== serialized_modifier.model_fuse_fn_name
385+
== obj_modifier.model_fuse_fn_name
386+
)
387+
assert (
388+
yaml_modifier.model_fuse_fn_kwargs
389+
== serialized_modifier.model_fuse_fn_kwargs
390+
== obj_modifier.model_fuse_fn_kwargs
391+
)

tests/sparseml/pytorch/sparsification/quantization/test_quantize.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def test_quantization_args_get_observer(
119119
input_activations=QuantizationArgs(num_bits=8, symmetric=False),
120120
weights=QuantizationArgs(num_bits=8, symmetric=True),
121121
output_activations=None,
122+
target_hardware="deepsparse",
122123
),
123124
),
124125
(
@@ -127,10 +128,11 @@ def test_quantization_args_get_observer(
127128
input_activations=QuantizationArgs(num_bits=8, symmetric=True),
128129
weights=QuantizationArgs(num_bits=8, symmetric=True),
129130
output_activations=None,
131+
target_hardware="tensorrt",
130132
),
131133
),
132134
# adding to raise an issue if default scheme changes from deepsparse
133-
("deepsparse", QuantizationScheme()),
135+
("deepsparse", QuantizationScheme(target_hardware="deepsparse")),
134136
],
135137
)
136138
def test_load_quantization_scheme_from_str(scheme_str, expected_scheme):

0 commit comments

Comments
 (0)