@@ -71,7 +71,7 @@ def _select_quantization_modifier(state: Dict[str, Any]) -> Type:
71
71
class QuantizationModifier (ScheduledModifier ):
72
72
"""
73
73
Enables quantization aware training (QAT) for a given module or its submodules
74
- After the start epoch, the specified module(s)' forward pass will emulate
74
+ After the start epoch, the specified module(s) forward pass will emulate
75
75
quantized execution and the modifier will be enabled until training is completed.
76
76
77
77
| Sample yaml:
@@ -99,6 +99,7 @@ class QuantizationModifier(ScheduledModifier):
99
99
| exclude_module_types: ["ReLU"]
100
100
| disable_quantization_observer_epoch: 2.0
101
101
| freeze_bn_stats_epoch: 3.0
102
+ | model_fuse_fn_name: 'fuse_module'
102
103
103
104
:param start_epoch: The epoch to start the modifier at
104
105
:param default_scheme: Default QuantizationScheme to use when enabling quantization
@@ -123,6 +124,12 @@ class QuantizationModifier(ScheduledModifier):
123
124
not be updated. Leave None to not disable observers during QAT. Default is None
124
125
:param freeze_bn_stats_epoch: Epoch to stop the tracking of batch norm stats. Leave
125
126
None to not stop tracking batch norm stats during QAT. Default is None
127
+ :param model_fuse_fn_name: Name of model function to fuse the model in place prior
128
+ to performing QAT. Set as 'no_fuse' to skip module fusing. Set as
129
+ 'conv_bn_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`.
130
+ Default is None
131
+ :param model_fuse_fn_kwargs: dictionary of keyword argument values to be passed
132
+ to the model fusing function
126
133
:param num_calibration_steps: Number of steps to run post training calibration for.
127
134
When None, the entire calibration_dataloader is used
128
135
:param end_epoch: Disabled, setting to anything other than -1 will raise an
@@ -138,6 +145,8 @@ def __init__(
138
145
exclude_module_types : Optional [List [str ]] = None ,
139
146
disable_quantization_observer_epoch : Optional [float ] = None ,
140
147
freeze_bn_stats_epoch : Optional [float ] = None ,
148
+ model_fuse_fn_name : Optional [str ] = None ,
149
+ model_fuse_fn_kwargs : Optional [Dict [str , Any ]] = None ,
141
150
num_calibration_steps : Optional [int ] = None ,
142
151
end_epoch : float = - 1.0 ,
143
152
):
@@ -164,10 +173,20 @@ def __init__(
164
173
self ._calibration_dataloader = None
165
174
self ._calibration_function = None
166
175
176
+ self ._model_fuse_fn_name = model_fuse_fn_name
177
+ self ._model_fuse_fn_kwargs = model_fuse_fn_kwargs or {}
178
+ if (
179
+ isinstance (self ._model_fuse_fn_name , str )
180
+ and self ._model_fuse_fn_name .lower () == "none"
181
+ ):
182
+ self ._model_fuse_fn_name = None
183
+
167
184
self ._qat_enabled = False
168
185
self ._quantization_observer_disabled = False
169
186
self ._bn_stats_frozen = False
170
187
188
+ self ._validate_params ()
189
+
171
190
@BaseModifier .sparsification_types .getter
172
191
def sparsification_types (self ) -> List [SparsificationTypes ]:
173
192
"""
@@ -311,6 +330,40 @@ def num_calibration_steps(self, value: Optional[int]):
311
330
"""
312
331
self ._num_calibration_steps = value
313
332
333
@ModifierProp()
def model_fuse_fn_name(self) -> Optional[str]:
    """
    :return: name of the function used to fuse the model in place prior to
        performing QAT. None selects the default behavior: 'no_fuse' when the
        tensorrt flag is True, otherwise
        `sparseml.pytorch.utils.fuse_module_conv_bn_relus`.
    """
    return self._model_fuse_fn_name
342
+
343
@model_fuse_fn_name.setter
def model_fuse_fn_name(self, value: Optional[str]):
    """
    :param value: name of the function used to fuse the model in place prior
        to performing QAT. Set None to use the default function
        `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Set as 'no_fuse'
        to skip module fusing.
    """
    # normalize the literal string "none" (any casing) to a true None
    # before storing, so downstream checks only deal with the None object
    if isinstance(value, str) and value.lower() == "none":
        value = None
    self._model_fuse_fn_name = value
    self._validate_params()
358
+
359
@ModifierProp()
def model_fuse_fn_kwargs(self) -> Dict[str, Any]:
    """
    :return: keyword arguments forwarded to the model fuse function when
        fusing is performed
    """
    return self._model_fuse_fn_kwargs
366
+
314
367
def initialize (
315
368
self ,
316
369
module : Module ,
@@ -434,7 +487,7 @@ def _freeze_bn_stats_update_ready(self, epoch: float) -> bool:
434
487
435
488
def _enable_module_qat (self , module : Module ):
436
489
# fuse conv-bn-relu blocks prior to quantization emulation
437
- fuse_module_conv_bn_relus (module , inplace = True )
490
+ self . _fuse (module )
438
491
439
492
# add quantization_schemes to target submodules
440
493
set_quantization_schemes (
@@ -458,7 +511,22 @@ def _enable_module_qat(self, module: Module):
458
511
459
512
self ._calibrate_if_possible (module )
460
513
461
- def _calibrate_if_possible (self , module ):
514
def _fuse(self, module: Module):
    """
    Fuse the given module in place prior to enabling QAT.

    Dispatch is driven by ``model_fuse_fn_name``:
    * None or 'conv_bn_relus' -> ``fuse_module_conv_bn_relus``
    * 'no_fuse' -> skip fusing entirely
    * anything else -> looked up as a callable attribute on ``module``

    :param module: the module to fuse in place
    :raises ValueError: if ``model_fuse_fn_name`` does not name a callable
        attribute of ``module``
    """
    if self.model_fuse_fn_name in [None, "conv_bn_relus"]:
        # merge inplace=True into a local copy instead of writing it into
        # self._model_fuse_fn_kwargs - the original code permanently leaked
        # the 'inplace' key into the user-supplied kwargs dict
        fuse_kwargs = dict(self._model_fuse_fn_kwargs, inplace=True)
        fuse_module_conv_bn_relus(module, **fuse_kwargs)
    elif self.model_fuse_fn_name != "no_fuse":
        module_fuse_fn = getattr(module, self._model_fuse_fn_name, None)
        if module_fuse_fn is None or not callable(module_fuse_fn):
            raise ValueError(
                "Invalid model_fuse_fn_name. "
                "Module has no callable function {}".format(
                    self._model_fuse_fn_name
                )
            )
        module_fuse_fn(**self._model_fuse_fn_kwargs)
528
+
529
+ def _calibrate_if_possible (self , module : Module ):
462
530
if self .num_calibration_steps == 0 and self ._calibration_dataloader :
463
531
warnings .warn (
464
532
f"num_calibration_steps is { self .num_calibration_steps } ."
@@ -477,7 +545,7 @@ def _calibrate_if_possible(self, module):
477
545
elif self ._calibration_dataloader :
478
546
self ._calibrate (module )
479
547
480
- def _calibrate (self , module ):
548
+ def _calibrate (self , module : Module ):
481
549
_LOGGER .info ("Running quantization calibration using calibration_dataloader" )
482
550
483
551
module_training = module .training
@@ -530,6 +598,20 @@ def _validate_params(self):
530
598
)
531
599
)
532
600
601
# collect every quantization scheme in play so tensorrt targeting can be
# detected across the default, submodule, and module-type schemes
all_schemes = [self._default_scheme]
if self._submodule_schemes:
    all_schemes += list(self._submodule_schemes.values())
# BUG FIX: guard on _module_type_schemes (was mistakenly
# _model_fuse_fn_kwargs), which is the dict actually read below -
# the old guard raised AttributeError when fuse kwargs were set but
# _module_type_schemes was None, and skipped module-type schemes otherwise
if self._module_type_schemes:
    all_schemes += list(self._module_type_schemes.values())
if any(scheme.target_hardware == "tensorrt" for scheme in all_schemes) and (
    self._model_fuse_fn_name != "no_fuse"
):
    # tensorrt deployments expect unfused graphs; force-skip the fuse step
    _LOGGER.info(
        "QuantizationModifier - target hardware tensorrt detected - "
        "Disabling model fuse step"
    )
    self._model_fuse_fn_name = "no_fuse"
614
+
533
615
534
616
class _QuantizationSchemesDict (dict ):
535
617
# wrapper class for dict to override the __str__ method for yaml serialization
0 commit comments