 
 from typing import Any, Dict, List, Optional, Type
 
+import torch
 from torch.nn import Module
 from torch.optim.optimizer import Optimizer
 
@@ -31,6 +32,7 @@
 )
 from sparseml.pytorch.sparsification.quantization.helpers import (
     configure_module_bn_wrappers,
+    freeze_bn_stats,
     fuse_module_conv_bn_relus,
 )
 from sparseml.pytorch.sparsification.quantization.legacy_modifier_quantization import (
@@ -89,6 +91,8 @@ class QuantizationModifier(ScheduledModifier):
     |           num_bits: 8
     |           symmetric: True
     |       exclude_module_types: ["ReLU"]
+    |       disable_quantization_observer_epoch: 2.0
+    |       freeze_bn_stats_epoch: 3.0
 
     :param start_epoch: The epoch to start the modifier at
     :param default_scheme: Default QuantizationScheme to use when enabling quantization
@@ -108,6 +112,11 @@ class QuantizationModifier(ScheduledModifier):
         specification to quantize that module type with. Default is None
     :param exclude_module_types: optional list of module class names
         to not quantize. Default is None
+    :param disable_quantization_observer_epoch: Epoch to disable updates to the module
+        quantization observers. At this point, quantized weights and zero points will
+        not be updated. Leave None to not disable observers during QAT. Default is None
+    :param freeze_bn_stats_epoch: Epoch to stop the tracking of batch norm stats. Leave
+        None to not stop tracking batch norm stats during QAT. Default is None
     :param end_epoch: Disabled, setting to anything other than -1 will raise an
         exception. For compatibility with YAML serialization only.
     """
@@ -119,6 +128,8 @@ def __init__(
         submodule_schemes: Optional[Dict[str, QuantizationSchemeLoadable]] = None,
         module_type_schemes: Optional[Dict[str, QuantizationSchemeLoadable]] = None,
         exclude_module_types: Optional[List[str]] = None,
+        disable_quantization_observer_epoch: Optional[float] = None,
+        freeze_bn_stats_epoch: Optional[float] = None,
         end_epoch: float = -1.0,
     ):
         raise_if_torch_quantization_not_available()
@@ -137,8 +148,12 @@ def __init__(
             module_type_schemes, self._default_scheme
         )
         self._exclude_module_types = exclude_module_types
+        self._disable_quantization_observer_epoch = disable_quantization_observer_epoch
+        self._freeze_bn_stats_epoch = freeze_bn_stats_epoch
 
         self._qat_enabled = False
+        self._quantization_observer_disabled = False
+        self._bn_stats_frozen = False
 
     @BaseModifier.sparsification_types.getter
     def sparsification_types(self) -> List[SparsificationTypes]:
@@ -231,6 +246,42 @@ def exclude_module_types(self, value: Optional[List[str]]):
         """
         self._exclude_module_types = value
 
+    @ModifierProp()
+    def disable_quantization_observer_epoch(self) -> Optional[float]:
+        """
+        :return: Epoch to disable updates to the module quantization observers.
+            At this point, quantized weights and zero points will not be updated.
+            When None, observers are never disabled during QAT
+        """
+        return self._disable_quantization_observer_epoch
+
+    @disable_quantization_observer_epoch.setter
+    def disable_quantization_observer_epoch(self, value: Optional[float]):
+        """
+        :param value: Epoch to disable updates to the module quantization
+            observers. At this point, quantized weights and zero points will
+            not be updated. Set None to not disable observers during QAT
+        """
+        self._disable_quantization_observer_epoch = value
+        self._validate_params()
+
+    @ModifierProp()
+    def freeze_bn_stats_epoch(self) -> Optional[float]:
+        """
+        :return: Epoch to stop the tracking of batch norm stats. When None,
+            batch norm stats are tracked for all of training
+        """
+        return self._freeze_bn_stats_epoch
+
+    @freeze_bn_stats_epoch.setter
+    def freeze_bn_stats_epoch(self, value: Optional[float]):
+        """
+        :param value: Epoch to stop the tracking of batch norm stats. Set
+            None to not stop tracking batch norm stats during QAT
+        """
+        self._freeze_bn_stats_epoch = value
+        self._validate_params()
+
     def initialize(
         self,
         module: Module,
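Because both setters end by calling _validate_params(), the thresholds can also be adjusted on an existing modifier and are re-checked against start_epoch each time; a short sketch reusing the constructor call from the earlier example (values illustrative):

modifier = QuantizationModifier(start_epoch=0.0)
modifier.disable_quantization_observer_epoch = 2.0  # setter re-runs _validate_params()
modifier.freeze_bn_stats_epoch = 3.0  # likewise re-validated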
@@ -285,15 +336,61 @@ def update_ready(self, epoch: float, steps_per_epoch: int) -> bool:
             return False
 
         pending = self.start_pending(epoch, steps_per_epoch)
+        pending |= self._freeze_bn_stats_update_ready(epoch)
+        pending |= self._disable_quantization_observer_update_ready(epoch)
 
         return pending
 
+    def advance_epochs(self, ref_start_epoch: Optional[float] = None):
+        """
+        Advance epoch attributes given a reference start epoch
+
+        :param ref_start_epoch: the reference, i.e. new, start epoch
+        """
+        if ref_start_epoch is None:
+            return
+
+        super().advance_epochs(ref_start_epoch=ref_start_epoch)
+
+        if self._disable_quantization_observer_epoch is not None:
+            self._disable_quantization_observer_epoch = (
+                max(0.0, self._disable_quantization_observer_epoch) + ref_start_epoch
+            )
+
+        if self._freeze_bn_stats_epoch is not None:
+            self._freeze_bn_stats_epoch = (
+                max(0.0, self._freeze_bn_stats_epoch) + ref_start_epoch
+            )
+        self._validate_params()
+
     def _check_quantization_update(
         self, module: Module, epoch: float, steps_per_epoch: int
     ):
         if self.start_pending(epoch, steps_per_epoch) and not self._qat_enabled:
             self._enable_module_qat(module)
 
+        if self._disable_quantization_observer_update_ready(epoch):
+            module.apply(torch.quantization.disable_observer)
+            self._quantization_observer_disabled = True
+
+        if self._freeze_bn_stats_update_ready(epoch):
+            module.apply(freeze_bn_stats)
+            self._bn_stats_frozen = True
+
+    def _disable_quantization_observer_update_ready(self, epoch: float) -> bool:
+        return (
+            self._disable_quantization_observer_epoch is not None
+            and epoch >= self._disable_quantization_observer_epoch
+            and not self._quantization_observer_disabled
+        )
+
+    def _freeze_bn_stats_update_ready(self, epoch: float) -> bool:
+        return (
+            self._freeze_bn_stats_epoch is not None
+            and epoch >= self._freeze_bn_stats_epoch
+            and not self._bn_stats_frozen
+        )
+
     def _enable_module_qat(self, module: Module):
         # fuse conv-bn-relu blocks prior to quantization emulation
         fuse_module_conv_bn_relus(module, inplace=True)
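With this hunk, update_ready reports pending work once either threshold epoch is reached, and _check_quantization_update applies each action at most once via the _quantization_observer_disabled and _bn_stats_frozen latches. At the torch level, the two module.apply calls reduce to the sketch below, which assumes a model already prepared for QAT; torch.nn.intrinsic.qat.freeze_bn_stats is torch's built-in counterpart standing in here for sparseml's freeze_bn_stats helper:

import torch
import torch.nn as nn

# a minimal QAT-prepared model for illustration
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model.train()
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.fuse_modules(model, [["0", "1", "2"]], inplace=True)
torch.quantization.prepare_qat(model, inplace=True)

# what happens at disable_quantization_observer_epoch: observers stop
# updating, freezing quantization scales and zero points from here on
model.apply(torch.quantization.disable_observer)

# what happens at freeze_bn_stats_epoch: fused conv-bn QAT modules stop
# updating their running batch norm statistics
model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

Note also the arithmetic in advance_epochs: with freeze_bn_stats_epoch=3.0 and ref_start_epoch=10.0, the attribute becomes max(0.0, 3.0) + 10.0 == 13.0, so each threshold stays relative to the stage's new start epoch.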
@@ -318,6 +415,30 @@ def _enable_module_qat(self, module: Module):
 
         self._qat_enabled = True
 
+    def _validate_params(self):
+        self.validate_schedule()
+        if (
+            self._disable_quantization_observer_epoch is not None
+            and self._disable_quantization_observer_epoch < self._start_epoch
+        ):
+            raise ValueError(
+                f"disable_quantization_observer_epoch may not be less than "
+                f"start_epoch for QuantizationModifier, received: "
+                f"{self._disable_quantization_observer_epoch} with start_epoch "
+                f"{self._start_epoch}"
+            )
+
+        if (
+            self._freeze_bn_stats_epoch is not None
+            and self._freeze_bn_stats_epoch < self._start_epoch
+        ):
+            raise ValueError(
+                f"freeze_bn_stats_epoch may not be less than start_epoch "
+                f"for QuantizationModifier, received: {self._freeze_bn_stats_epoch} "
+                f"with start_epoch {self._start_epoch}"
+            )
+
 
 class _QuantizationSchemesDict(dict):
     # wrapper class for dict to override the __str__ method for yaml serialization
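Finally, a quick sketch of the new validation path, continuing the constructor example from earlier (values illustrative, remaining arguments left at their defaults): a threshold earlier than start_epoch is rejected.

modifier = QuantizationModifier(start_epoch=2.0)
try:
    modifier.freeze_bn_stats_epoch = 1.0  # earlier than start_epoch=2.0
except ValueError as err:
    print(err)  # "freeze_bn_stats_epoch may not be less than start_epoch ..."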