Support Int4 ONNX Export #1670

Merged
15 commits merged on Jul 19, 2023
8 changes: 7 additions & 1 deletion src/sparseml/pytorch/torch_to_onnx_exporter.py
@@ -28,7 +28,11 @@
from sparseml.exporters.transforms.base_transform import BaseTransform
from sparseml.pytorch import _PARSED_TORCH_VERSION
from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
from sparseml.pytorch.utils.helpers import tensors_module_forward, tensors_to_device
from sparseml.pytorch.utils.helpers import (
adjust_quantization_for_onnx_export,
tensors_module_forward,
tensors_to_device,
)
from sparseml.pytorch.utils.model import is_parallel_model
from sparsezoo.utils import save_onnx

@@ -190,6 +194,8 @@ def transform(
)
}

adjust_quantization_for_onnx_export(module) # in-place operation

# disable active quantization observers because they cannot be exported
disabled_observers = []
for submodule in module.modules():
9 changes: 7 additions & 2 deletions src/sparseml/pytorch/utils/exporter.py
@@ -37,6 +37,7 @@
from sparseml.onnx.utils import ONNXGraph
from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
from sparseml.pytorch.utils.helpers import (
adjust_quantization_for_onnx_export,
tensors_export,
tensors_module_forward,
tensors_to_device,
@@ -208,17 +209,21 @@ def export_onnx(
"""
if not export_kwargs:
export_kwargs = {}

module = deepcopy(self._module).cpu() # don't modify the original model
if "output_names" not in export_kwargs:
sample_batch = tensors_to_device(sample_batch, "cpu")
module = deepcopy(self._module).cpu()
module.eval()
with torch.no_grad():
out = tensors_module_forward(
sample_batch, module, check_feat_lab_inp=False
)
export_kwargs["output_names"] = self.get_output_names(out)

adjust_quantization_for_onnx_export(module) # in-place operation

export_onnx(
module=self._module,
module=module,
sample_batch=sample_batch,
file_path=os.path.join(self._output_dir, name),
opset=opset,
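
The exporter change above boils down to working on a deep copy of the module and exporting that copy instead of self._module. A minimal sketch of that pattern with a plain torch module (the Linear model and paths are illustrative only, not part of the diff):

# Minimal sketch of the copy-then-export pattern the exporter.py change adopts.
import copy
import os
import tempfile

import torch

model = torch.nn.Linear(8, 4)
sample_batch = torch.randn(1, 8)

module = copy.deepcopy(model).cpu()  # work on a copy so the caller's model stays untouched
module.eval()
with torch.no_grad():
    module(sample_batch)  # sample forward pass, as export_onnx does to derive output names

with tempfile.TemporaryDirectory() as tmp_dir:
    # export the adjusted copy, not the original model
    torch.onnx.export(module, sample_batch, os.path.join(tmp_dir, "model.onnx"), opset_version=13)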
41 changes: 41 additions & 0 deletions src/sparseml/pytorch/utils/helpers.py
@@ -28,6 +28,7 @@

import numpy
import torch
from packaging import version
from torch import Tensor
from torch.nn import Linear, Module, Parameter
from torch.nn.modules.conv import Conv2d, Conv3d, _ConvNd
@@ -101,10 +102,12 @@
"memory_aware_threshold",
"download_framework_model_by_recipe_type",
"detach",
"adjust_quantization_for_onnx_export",
]


_LOGGER = logging.getLogger(__name__)
_PARSED_TORCH_VERSION = version.parse(torch.__version__)


##############################
@@ -1174,3 +1177,41 @@ def detach(x: Union[torch.Tensor, List, Tuple]):
return tuple([detach(e) for e in x])
else:
raise ValueError("Unexpected type to detach")


def adjust_quantization_for_onnx_export(module: torch.nn.Module) -> torch.nn.Module:
# supported pytorch ranges are int8 or uint8
allowed_ranges = [(0, 127), (0, 255), (-128, 127)]
fake_quant_modules = [
m for m in module.modules() if m.__class__.__name__ == "FakeQuantize"
]

if _PARSED_TORCH_VERSION >= version.parse("1.12"):
for quant in fake_quant_modules:
# original ranges preserved in quant.quant_min and quant.quant_max
quant_range = (
quant.activation_post_process.quant_min,
quant.activation_post_process.quant_max,
)
if quant_range not in allowed_ranges:
if quant_range[0] < 0: # convert signed range to int8
quant.activation_post_process.quant_min = -128
quant.activation_post_process.quant_max = 127
else: # convert unsigned range to uint8
quant.activation_post_process.quant_min = 0
quant.activation_post_process.quant_max = 255
# don't update observer since ranges are artificially modified
quant.observer_enabled[0] = 0

else: # backwards compatibility for torch <= 1.11
for quant in fake_quant_modules:
quant_range = (quant.quant_min, quant.quant_max)
if quant_range not in allowed_ranges:
if quant_range[0] < 0: # convert signed range to int8
quant.quant_min = -128
quant.quant_max = 127
else: # convert unsigned range to uint8
quant.quant_min = 0
quant.quant_max = 255
# don't update observer since ranges are artificially modified
quant.observer_enabled[0] = 0
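
The helper's core rule, restated as a standalone sketch (the names below are illustrative and not part of sparseml): any fake-quant range that is not one of the exportable int8/uint8 ranges is widened to the matching 8-bit range, while the observer stays disabled so the previously calibrated int4 scale and zero point survive the export.

# Standalone sketch of the range-widening rule adjust_quantization_for_onnx_export applies.
ALLOWED_RANGES = [(0, 127), (0, 255), (-128, 127)]


def widen_range(quant_min: int, quant_max: int) -> tuple:
    """Map an arbitrary fake-quant range onto an ONNX-exportable 8-bit range."""
    if (quant_min, quant_max) in ALLOWED_RANGES:
        return quant_min, quant_max
    # signed ranges (e.g. int4's -8..7) widen to int8, unsigned ones to uint8
    return (-128, 127) if quant_min < 0 else (0, 255)


assert widen_range(-8, 7) == (-128, 127)      # int4 weights exported in an int8 container
assert widen_range(0, 15) == (0, 255)         # uint4 activations exported in a uint8 container
assert widen_range(-128, 127) == (-128, 127)  # already exportable, left unchanged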
117 changes: 117 additions & 0 deletions tests/sparseml/pytorch/test_torch_to_onnx_exporter.py
@@ -23,14 +23,131 @@
import torch

from sparseml.exporters.onnx_to_deepsparse import ONNXToDeepsparse
from sparseml.onnx.utils.helpers import get_init_by_name
from sparseml.pytorch.models.registry import ModelRegistry
from sparseml.pytorch.optim import ScheduledModifierManager
from sparseml.pytorch.sparsification.quantization import QuantizationModifier
from sparseml.pytorch.torch_to_onnx_exporter import TorchToONNX
from sparseml.pytorch.utils import ModuleExporter
from sparsezoo.utils import validate_onnx
from tests.sparseml.pytorch.helpers import ConvNet, LinearNet, MLPNet


QUANT_RECIPE = """
!QuantizationModifier
start_epoch: 0.0
scheme:
input_activations:
num_bits: 8
symmetric: False
weights:
num_bits: 4
symmetric: True
"""


def _get_4bit_modules(model):
fake_quant_modules = [
module
for module in model.modules()
if module.__class__.__name__ == "FakeQuantize"
]
int4_fake_quant_modules = [
quant_module
for quant_module in fake_quant_modules
if quant_module.activation_post_process.quant_min == -8
and quant_module.activation_post_process.quant_max == 7
]

return int4_fake_quant_modules


def _get_conv_quant_ranges(onnx_model):
conv_ranges = {}
for node in onnx_model.graph.node:
if node.op_type == "ConvInteger":
x, w, x_zero_point, w_zero_point = node.input
zero_value = get_init_by_name(onnx_model, w_zero_point)
zero = onnx.numpy_helper.to_array(zero_value)
weights_value = get_init_by_name(onnx_model, w)
weights = onnx.numpy_helper.to_array(weights_value)
converted = (weights - zero).astype("int8")
cmin, cmax = converted.min(), converted.max()
range = cmax - cmin
conv_ranges[node.name] = range

return conv_ranges


@pytest.mark.parametrize(
"model,sample_batch",
[
(MLPNet(), torch.randn(8)),
(MLPNet(), torch.randn(10, 8)),
(LinearNet(), torch.randn(8)),
(LinearNet(), torch.randn(10, 8)),
(ConvNet(), torch.randn(1, 3, 28, 28)),
],
)
def test_export_4bit_model(tmp_path, model, sample_batch):
old_dir = tmp_path / "old_exporter"
old_dir.mkdir()
new_dir = tmp_path / "new_exporter"
new_dir.mkdir()

manager = ScheduledModifierManager.from_yaml(QUANT_RECIPE)
manager.apply(model)

# ensure 4bit quantization correctly applied
num_4bit_modules = len(_get_4bit_modules(model))
assert num_4bit_modules > 0

new_exporter = TorchToONNX(sample_batch)
new_exporter.export(model, new_dir / "model.onnx")
ONNXToDeepsparse(use_qlinear_conv=True).export(
new_dir / "model.onnx", new_dir / "model.onnx"
)
validate_onnx(str(new_dir / "model.onnx"))

# ensure export didn't modify original model
assert len(_get_4bit_modules(model)) == num_4bit_modules

old_exporter = ModuleExporter(model, old_dir)
old_exporter.export_onnx(sample_batch, convert_qat=True)
validate_onnx(str(old_dir / "model.onnx"))

# ensure export didn't modify original model
assert len(_get_4bit_modules(model)) == num_4bit_modules


def test_export_4bit_model_range(tmp_path):
model, sample_batch = ConvNet(), torch.randn(1, 3, 28, 28)
old_dir = tmp_path / "old_exporter"
old_dir.mkdir()
new_dir = tmp_path / "new_exporter"
new_dir.mkdir()

manager = ScheduledModifierManager.from_yaml(QUANT_RECIPE)
manager.apply(model)

new_exporter = TorchToONNX(sample_batch)
new_exporter.export(model, new_dir / "model.onnx")
ONNXToDeepsparse(use_qlinear_conv=True).export(
new_dir / "model.onnx", new_dir / "model.onnx"
)
onnx_model_new = onnx.load(new_dir / "model.onnx")
conv_quant_ranges = _get_conv_quant_ranges(onnx_model_new)
# all ConvInteger blocks should be quantized to int4
assert all([conv_range <= 16 for name, conv_range in conv_quant_ranges.items()])

old_exporter = ModuleExporter(model, old_dir)
old_exporter.export_onnx(sample_batch, convert_qat=True)
onnx_model_old = onnx.load(old_dir / "model.onnx")
conv_quant_ranges = _get_conv_quant_ranges(onnx_model_old)
# all ConvInteger blocks should be quantized to int4
assert all([conv_range <= 16 for name, conv_range in conv_quant_ranges.items()])


@pytest.mark.parametrize(
"model,sample_batch",
[
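
The range test above relies on a simple arithmetic fact: int4-quantized weights take at most 16 distinct integer levels, so after subtracting the zero point the stored 8-bit ConvInteger initializer must span no more than 16. A synthetic, self-contained illustration of that check (data and names are made up, not taken from the PR):

# Synthetic illustration of the int4 span check performed by _get_conv_quant_ranges.
import numpy as np

rng = np.random.default_rng(0)
zero_point = 128                                       # hypothetical uint8 zero point
int4_levels = rng.integers(-8, 8, size=(16, 3, 3, 3))  # values an int4 weight may take
stored = (int4_levels + zero_point).astype("uint8")    # how a ConvInteger initializer stores them

converted = (stored.astype("int32") - zero_point).astype("int8")
span = int(converted.max()) - int(converted.min())
assert span <= 16  # matches the tests' bound: int4 weights cover at most 16 levels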