
Commit 3ae10f2

Author: Benjamin
[QuantizationModifier] freeze bn stats and disable observers for QAT finetuning support

1 parent 07e7446
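
For context, the two new options attach directly to the modifier's constructor. A minimal usage sketch with illustrative epoch values, assuming QuantizationModifier is importable from the module changed below:

    from sparseml.pytorch.sparsification.quantization.modifier_quantization import (
        QuantizationModifier,
    )

    # illustrative schedule: QAT starts at epoch 0, batch norm running stats
    # freeze at epoch 2, and observer updates stop at epoch 3 so quantization
    # scales and zero points stay fixed for the remaining finetuning epochs
    modifier = QuantizationModifier(
        start_epoch=0.0,
        freeze_bn_stats_epoch=2.0,
        disable_quantization_observer_epoch=3.0,
    )

Both epochs must be at or after start_epoch; the new _validate_params check below rejects earlier values.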

File tree: 2 files changed (+167, -3 lines)

src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py (+121)

@@ -21,6 +21,7 @@

 from typing import Any, Dict, List, Optional, Type

+import torch
 from torch.nn import Module
 from torch.optim.optimizer import Optimizer

@@ -31,6 +32,7 @@
 )
 from sparseml.pytorch.sparsification.quantization.helpers import (
     configure_module_bn_wrappers,
+    freeze_bn_stats,
     fuse_module_conv_bn_relus,
 )
 from sparseml.pytorch.sparsification.quantization.legacy_modifier_quantization import (

@@ -89,6 +91,8 @@ class QuantizationModifier(ScheduledModifier):
 |           num_bits: 8
 |           symmetric: True
 |   exclude_module_types: ["ReLU"]
+|   disable_quantization_observer_epoch: 2.0
+|   freeze_bn_stats_epoch: 3.0

 :param start_epoch: The epoch to start the modifier at
 :param default_scheme: Default QuantizationScheme to use when enabling quantization

@@ -108,6 +112,11 @@ class QuantizationModifier(ScheduledModifier):
     specification to quantize that module type with. Default is None
 :param exclude_module_types: optional list of module class names
     to not quantize. Default is None
+:param disable_quantization_observer_epoch: Epoch to disable updates to the module
+    quantization observers. At this point, quantized weights and zero points will
+    not be updated. Leave None to not disable observers during QAT. Default is None
+:param freeze_bn_stats_epoch: Epoch to stop the tracking of batch norm stats. Leave
+    None to not stop tracking batch norm stats during QAT. Default is None
 :param end_epoch: Disabled, setting to anything other than -1 will raise an
     exception. For compatibility with YAML serialization only.
 """

@@ -119,6 +128,8 @@ def __init__(
     submodule_schemes: Optional[Dict[str, QuantizationSchemeLoadable]] = None,
     module_type_schemes: Optional[Dict[str, QuantizationSchemeLoadable]] = None,
     exclude_module_types: Optional[List[str]] = None,
+    disable_quantization_observer_epoch: Optional[float] = None,
+    freeze_bn_stats_epoch: Optional[float] = None,
     end_epoch: float = -1.0,
 ):
     raise_if_torch_quantization_not_available()

@@ -137,8 +148,12 @@ def __init__(
         module_type_schemes, self._default_scheme
     )
     self._exclude_module_types = exclude_module_types
+    self._disable_quantization_observer_epoch = disable_quantization_observer_epoch
+    self._freeze_bn_stats_epoch = freeze_bn_stats_epoch

     self._qat_enabled = False
+    self._quantization_observer_disabled = False
+    self._bn_stats_frozen = False

 @BaseModifier.sparsification_types.getter
 def sparsification_types(self) -> List[SparsificationTypes]:

@@ -231,6 +246,42 @@ def exclude_module_types(self, value: Optional[List[str]]):
     """
     self._exclude_module_types = value

+@ModifierProp()
+def disable_quantization_observer_epoch(self) -> Optional[float]:
+    """
+    :return: Epoch to disable updates to the module
+        quantization observers. At this point, quantized weights and zero points
+        will not be updated. When None, observers are never disabled during QAT
+    """
+    return self._disable_quantization_observer_epoch
+
+@disable_quantization_observer_epoch.setter
+def disable_quantization_observer_epoch(self, value: Optional[float]):
+    """
+    :param value: Epoch to disable updates to the module
+        quantization observers. At this point, quantized weights and zero points
+        will not be updated. Set None to not disable observers during QAT
+    """
+    self._disable_quantization_observer_epoch = value
+    self._validate_params()
+
+@ModifierProp()
+def freeze_bn_stats_epoch(self) -> Optional[float]:
+    """
+    :return: Epoch to stop the tracking of batch norm stats. When
+        None, batch norm stats are tracked for all of training
+    """
+    return self._freeze_bn_stats_epoch
+
+@freeze_bn_stats_epoch.setter
+def freeze_bn_stats_epoch(self, value: Optional[float]):
+    """
+    :param value: Epoch to stop the tracking of batch norm stats. Set
+        None to not stop tracking batch norm stats during QAT
+    """
+    self._freeze_bn_stats_epoch = value
+    self._validate_params()
+
 def initialize(
     self,
     module: Module,

@@ -285,15 +336,61 @@ def update_ready(self, epoch: float, steps_per_epoch: int) -> bool:
         return False

     pending = self.start_pending(epoch, steps_per_epoch)
+    pending |= self._freeze_bn_stats_update_ready(epoch)
+    pending |= self._disable_quantization_observer_update_ready(epoch)

     return pending

+def advance_epochs(self, ref_start_epoch: float = None):
+    """
+    Advance epoch attributes given a reference start epoch
+
+    :param ref_start_epoch: the reference, i.e. new, start epoch
+    """
+    if ref_start_epoch is None:
+        return
+
+    super().advance_epochs(ref_start_epoch=ref_start_epoch)
+
+    if self._disable_quantization_observer_epoch is not None:
+        self._disable_quantization_observer_epoch = (
+            max(0.0, self._disable_quantization_observer_epoch) + ref_start_epoch
+        )
+
+    if self._freeze_bn_stats_epoch is not None:
+        self._freeze_bn_stats_epoch = (
+            max(0.0, self._freeze_bn_stats_epoch) + ref_start_epoch
+        )
+    self._validate_params()
+
 def _check_quantization_update(
     self, module: Module, epoch: float, steps_per_epoch: int
 ):
     if self.start_pending(epoch, steps_per_epoch) and not self._qat_enabled:
         self._enable_module_qat(module)

+    if self._disable_quantization_observer_update_ready(epoch):
+        module.apply(torch.quantization.disable_observer)
+        self._quantization_observer_disabled = True
+
+    if self._freeze_bn_stats_update_ready(epoch):
+        module.apply(freeze_bn_stats)
+        self._bn_stats_frozen = True
+
+def _disable_quantization_observer_update_ready(self, epoch: float) -> bool:
+    return (
+        self._disable_quantization_observer_epoch is not None
+        and epoch >= self._disable_quantization_observer_epoch
+        and not self._quantization_observer_disabled
+    )
+
+def _freeze_bn_stats_update_ready(self, epoch: float) -> bool:
+    return (
+        self._freeze_bn_stats_epoch is not None
+        and epoch >= self._freeze_bn_stats_epoch
+        and not self._bn_stats_frozen
+    )
+
 def _enable_module_qat(self, module: Module):
     # fuse conv-bn-relu blocks prior to quantization emulation
     fuse_module_conv_bn_relus(module, inplace=True)

@@ -318,6 +415,30 @@ def _enable_module_qat(self, module: Module):

     self._qat_enabled = True

+def _validate_params(self):
+    self.validate_schedule()
+    if (
+        self._disable_quantization_observer_epoch is not None
+        and self._disable_quantization_observer_epoch < self._start_epoch
+    ):
+        raise ValueError(
+            f"disable_quantization_observer_epoch may not be less than "
+            f"start_epoch for QuantizationModifier, received: "
+            f"{self._disable_quantization_observer_epoch} with start_epoch "
+            f"{self._start_epoch}"
+        )
+
+    if (
+        self._freeze_bn_stats_epoch is not None
+        and self._freeze_bn_stats_epoch < self._start_epoch
+    ):
+        raise ValueError(
+            "freeze_bn_stats_epoch may not be less than start_epoch"
+            " for QuantizationModifier, received: {} with start_epoch {}".format(
+                self._freeze_bn_stats_epoch, self._start_epoch
+            )
+        )
+

 class _QuantizationSchemesDict(dict):
     # wrapper class for dict to override the __str__ method for yaml serialization
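
The import of freeze_bn_stats points at the helpers module, whose body is outside this diff. As a rough sketch of what such a helper can look like (an assumption, not the actual SparseML implementation): PyTorch's fused QAT modules expose a freeze_bn_stats() method, so a per-module callable passed to Module.apply only needs to invoke it where it exists:

    import torch.nn.intrinsic.qat as nniqat

    def freeze_bn_stats(mod):
        # hypothetical sketch: fused QAT conv-bn modules keep updating their
        # batch norm running stats until freeze_bn_stats() is called on them
        if isinstance(mod, (nniqat.ConvBn2d, nniqat.ConvBnReLU2d)):
            mod.freeze_bn_stats()

module.apply(freeze_bn_stats) then visits every submodule once, mirroring how _check_quantization_update above applies torch.quantization.disable_observer to stop observer updates across the whole model.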

tests/sparseml/pytorch/sparsification/quantization/test_modifier_quantization.py (+46, -3)

@@ -122,6 +122,21 @@ def _test_qat_applied(modifier, model):
         assert not hasattr(module, "qconfig")


+def _test_freeze_bn_stats_observer_applied(modifier, epoch):
+    if modifier.disable_quantization_observer_epoch is not None and (
+        epoch >= modifier.disable_quantization_observer_epoch
+    ):
+        assert modifier._quantization_observer_disabled
+    else:
+        assert not modifier._quantization_observer_disabled
+    if modifier.freeze_bn_stats_epoch is not None and (
+        epoch >= modifier.freeze_bn_stats_epoch
+    ):
+        assert modifier._bn_stats_frozen
+    else:
+        assert not modifier._bn_stats_frozen
+
+
 @pytest.mark.skipif(
     os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False),
     reason="Skipping pytorch tests",

@@ -148,6 +163,8 @@
         lambda: QuantizationModifier(
             start_epoch=0.0,
             submodule_schemes=dict(seq="default"),
+            freeze_bn_stats_epoch=2.0,
+            disable_quantization_observer_epoch=3.0,
         ),
         LinearNet,
     ),

@@ -175,6 +192,8 @@
         lambda: QuantizationModifier(
             start_epoch=2.0,
             module_type_schemes=dict(Conv2d=QuantizationScheme(weights=None)),
+            freeze_bn_stats_epoch=2.5,
+            disable_quantization_observer_epoch=2.2,
         ),
         ConvNet,
     ),

@@ -206,10 +225,16 @@ def test_lifecycle(

         self.initialize_helper(modifier, model)

+        _test_freeze_bn_stats_observer_applied(modifier, 0.0)
         for epoch in range(int(modifier.start_epoch)):
             assert not modifier.update_ready(epoch, test_steps_per_epoch)

         update_epochs = [modifier.start_epoch]
+        if modifier.disable_quantization_observer_epoch is not None:
+            update_epochs.append(modifier.disable_quantization_observer_epoch)
+        if modifier.freeze_bn_stats_epoch is not None:
+            update_epochs.append(modifier.freeze_bn_stats_epoch)
+        update_epochs.sort()
         for epoch in update_epochs:
             assert modifier.update_ready(epoch, test_steps_per_epoch)
         # test update ready is still true after start epoch

@@ -225,9 +250,11 @@ def test_lifecycle(
         _test_qat_applied(modifier, model)
         pass

-        modifier.scheduled_update(
-            model, optimizer, modifier.start_epoch, test_steps_per_epoch
-        )
+        for update_epoch in update_epochs:
+            modifier.scheduled_update(
+                model, optimizer, update_epoch, test_steps_per_epoch
+            )
+            _test_freeze_bn_stats_observer_applied(modifier, update_epoch)

         # test update ready is False after start epoch is applied, before disable epochs
         if (

@@ -270,6 +297,8 @@ def test_quantization_modifier_yaml():
     )
     module_type_schemes = dict(Linear=dict(output_activations=dict(symmetric=False)))
     exclude_module_types = ["LayerNorm", "Tanh"]
+    disable_quantization_observer_epoch = 2.0
+    freeze_bn_stats_epoch = 3.0

     yaml_str = f"""
     !QuantizationModifier

@@ -278,6 +307,8 @@ def test_quantization_modifier_yaml():
         submodule_schemes: {submodule_schemes}
         module_type_schemes: {module_type_schemes}
         exclude_module_types: {exclude_module_types}
+        disable_quantization_observer_epoch: {disable_quantization_observer_epoch}
+        freeze_bn_stats_epoch: {freeze_bn_stats_epoch}
     """
     yaml_modifier = QuantizationModifier.load_obj(
         yaml_str

@@ -291,6 +322,8 @@ def test_quantization_modifier_yaml():
         submodule_schemes=submodule_schemes,
         module_type_schemes=module_type_schemes,
         exclude_module_types=exclude_module_types,
+        disable_quantization_observer_epoch=disable_quantization_observer_epoch,
+        freeze_bn_stats_epoch=freeze_bn_stats_epoch,
     )

     assert isinstance(yaml_modifier, QuantizationModifier)

@@ -322,3 +355,13 @@ def test_quantization_modifier_yaml():
         == serialized_modifier.exclude_module_types
         == obj_modifier.exclude_module_types
     )
+    assert (
+        yaml_modifier.disable_quantization_observer_epoch
+        == serialized_modifier.disable_quantization_observer_epoch
+        == obj_modifier.disable_quantization_observer_epoch
+    )
+    assert (
+        yaml_modifier.freeze_bn_stats_epoch
+        == serialized_modifier.freeze_bn_stats_epoch
+        == obj_modifier.freeze_bn_stats_epoch
+    )
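
The YAML round trip exercised by the new assertions can be reproduced in isolation; a minimal sketch, assuming the load_obj API used in the test above and illustrative epoch values:

    yaml_str = """
    !QuantizationModifier
        start_epoch: 0.0
        disable_quantization_observer_epoch: 2.0
        freeze_bn_stats_epoch: 3.0
    """
    modifier = QuantizationModifier.load_obj(yaml_str)
    assert modifier.disable_quantization_observer_epoch == 2.0
    assert modifier.freeze_bn_stats_epoch == 3.0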
