Skip to content

Register choose_qparams_affine_float8 as custom op #2461

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/dtypes/test_affine_quantized_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def test_mm_float8dq_per_row(
)
@common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
@common_utils.parametrize("block_size", [None, (1, 32), (2, 16), (4, 8)])
@common_utils.parametrize("block_size", [(), (1, 32), (2, 16), (4, 8)])
def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
"""Test _dequantize_affine_float8 with various configurations"""

Expand Down
28 changes: 28 additions & 0 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from parameterized import parameterized
from torch._dynamo import config
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck

import torchao
from torchao.dtypes import Int4CPULayout, Int4XPULayout, TensorCoreTiledLayout
Expand All @@ -37,6 +38,7 @@

# APIs to be deprecated (used for torch 2.2.2 and 2.3)
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
_replace_with_custom_fn_if_matches_filter,
change_linear_weights_to_int4_woqtensors,
change_linear_weights_to_int8_dqtensors,
Expand Down Expand Up @@ -86,6 +88,7 @@
check_cpu_version,
check_xpu_version,
is_fbcode,
is_sm_at_least_89,
is_sm_at_least_90,
unwrap_tensor_subclass,
)
Expand Down Expand Up @@ -2077,6 +2080,31 @@ def forward(self, x):
self.assertTrue(torch.ops.torchao.quantize_affine.default in targets)
self.assertFalse(torch.ops.aten.narrow.default in targets)

@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
def test_export_float8(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need this skip according to CI:

@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)

class SimpleNetwork(torch.nn.Module):
def __init__(self):
super(SimpleNetwork, self).__init__()
self.linear = torch.nn.Linear(
in_features=32, out_features=16, bias=False
)

def forward(self, x):
return self.linear(x)

model = SimpleNetwork().eval().cuda()
inp = torch.randn(2, 32).cuda()
config = Float8DynamicActivationFloat8WeightConfig()
quantize_(model, config)

ep = torch.export.export(model, (inp,))
print(ep)
FileCheck().check_count(
"torch.ops.torchao.choose_qparams_affine_float8.default", 1, exactly=True
).run(str(ep.graph))


class TestUtils(unittest.TestCase):
@parameterized.expand(
Expand Down
5 changes: 3 additions & 2 deletions torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -2178,11 +2178,12 @@ def _dequantize_affine_floatx(
return tensor


@register_custom_op
def _choose_qparams_affine_float8(
tensor: torch.Tensor,
block_size: List[int],
float8_dtype: torch.dtype = torch.float8_e4m3fn,
scale_dtype: torch.dtype = torch.float32,
block_size: Optional[Tuple[int, ...]] = None,
) -> torch.Tensor:
"""
Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity.
Expand All @@ -2195,7 +2196,7 @@ def _choose_qparams_affine_float8(
"""
quant_max = torch.finfo(float8_dtype).max
# only tensorwise scaling is supported for now:
if block_size is None:
if len(block_size) == 0:
max_abs = tensor.abs().max()
scale = max_abs / quant_max
else:
Expand Down
Loading