
Commit 7cfad32

Fixes to load pre-trained model w/ channel-wise quantization
1 parent 779af69

File tree

  • src/sparseml/pytorch/sparsification/quantization/helpers.py

1 file changed: +3 -3

src/sparseml/pytorch/sparsification/quantization/helpers.py (+3 -3)
@@ -725,7 +725,7 @@ def initialize_channel_wise_scale_zp(module: Module):
     for name, submodule in module.named_modules():
         weight_fake_quant = getattr(submodule, "weight_fake_quant", None)
         if not weight_fake_quant or (
-            getattr(weight_fake_quant, "qscheme", None) is not torch.per_channel_affine
+            getattr(weight_fake_quant, "qscheme", None) not in [torch.per_channel_affine, torch.per_channel_symmetric]
         ):
             # only consider modules with channel-wise quantized weights
             continue
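
The first hunk widens the qscheme check. PyTorch's default per-channel weight fake-quantization observers use torch.per_channel_symmetric rather than torch.per_channel_affine, so the old affine-only test skipped those modules and their scale/zero_point buffers were never resized. A minimal sketch of the difference, assuming a recent PyTorch with the torch.ao.quantization namespace (the module construction mirrors PyTorch's per-channel weight defaults but is illustrative):

```python
import torch
from torch.ao.quantization import FakeQuantize, MovingAveragePerChannelMinMaxObserver

# Build a per-channel weight fake-quant module the way PyTorch's defaults do:
# note the *symmetric* per-channel qscheme.
wfq = FakeQuantize(
    observer=MovingAveragePerChannelMinMaxObserver,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
    ch_axis=0,
)

# Old check: rejects the module even though it is channel-wise quantized.
print(wfq.qscheme is not torch.per_channel_affine)  # True -> module was skipped

# New check: accepts both per-channel schemes.
print(wfq.qscheme in [torch.per_channel_affine, torch.per_channel_symmetric])  # True
```

With the membership test, both per-channel schemes reach the initialization logic in the second hunk.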
@@ -743,11 +743,11 @@ def initialize_channel_wise_scale_zp(module: Module):
         # update scale and zero point if they are initialized to a size of 1
         scale = weight_fake_quant.scale
         if scale.numel() == 1:
-            weight_fake_quant.scale = scale.reshape(-1).expand(num_channels)
+            weight_fake_quant.scale = torch.ones(num_channels, dtype=scale.dtype)

         zero_point = weight_fake_quant.zero_point
         if zero_point.numel() == 1:
-            weight_fake_quant.zero_point = zero_point.reshape(-1).expand(num_channels)
+            weight_fake_quant.zero_point = torch.zeros(num_channels, dtype=zero_point.dtype)

         # update the observer min and max vals
         if weight_fake_quant.activation_post_process.min_val.numel() == 0: