
Commit ce8cd8e

Register choose_qparams_affine_float8 as custom op

1 parent 6dfba04
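
Why this matters: a plain Python helper gets inlined when the model is traced, so torch.export would only see the decomposed aten ops (abs, max, div); registering the helper as a custom op keeps a single opaque torch.ops.torchao.choose_qparams_affine_float8.default node in the exported graph that downstream passes can pattern-match. Below is a minimal sketch of the underlying PyTorch mechanism, using a hypothetical mylib::choose_scale_f8 op; torchao wires this up through its own register_custom_op helper, not this exact decorator.

import torch
from typing import List

# Hypothetical op name, for illustration only (requires PyTorch >= 2.4).
@torch.library.custom_op("mylib::choose_scale_f8", mutates_args=())
def choose_scale_f8(tensor: torch.Tensor, block_size: List[int]) -> torch.Tensor:
    # Tensorwise float8 scale: max(|t|) / largest representable float8 value.
    quant_max = torch.finfo(torch.float8_e4m3fn).max
    return (tensor.abs().max() / quant_max).to(torch.float32)

# Fake (meta) kernel so the op can be traced and exported without real data.
@choose_scale_f8.register_fake
def _(tensor, block_size):
    return torch.empty((), dtype=torch.float32, device=tensor.device)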

File tree

- test/dtypes/test_affine_quantized_float.py
- test/integration/test_integration.py
- torchao/quantization/quant_primitives.py

3 files changed: +29 -3 lines changed

test/dtypes/test_affine_quantized_float.py
Lines changed: 1 addition & 1 deletion

@@ -356,7 +356,7 @@ def test_mm_float8dq_per_row(
     )
     @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
     @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
-    @common_utils.parametrize("block_size", [None, (1, 32), (2, 16), (4, 8)])
+    @common_utils.parametrize("block_size", [(), (1, 32), (2, 16), (4, 8)])
     def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
         """Test _dequantize_affine_float8 with various configurations"""

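The parametrization swaps the None sentinel for an empty tuple: once the helper is registered as a custom op, its schema takes a required List[int] block_size (see the quant_primitives.py hunk below), and an empty size list now signals tensorwise scaling. Illustrative call shapes, assuming the private helper is invoked directly (internal API; exact behavior of the blockwise branch is not shown in this diff):

import torch
from torchao.quantization.quant_primitives import _choose_qparams_affine_float8

x = torch.randn(4, 8)
scale = _choose_qparams_affine_float8(x, block_size=[])     # tensorwise, formerly block_size=None
scale_b = _choose_qparams_affine_float8(x, block_size=[2, 4])  # blockwise else-branch
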
test/integration/test_integration.py
Lines changed: 25 additions & 0 deletions

@@ -14,6 +14,7 @@

 import torch
 import torch.nn as nn
+from torch.testing import FileCheck
 from parameterized import parameterized
 from torch._dynamo import config
 from torch._inductor.utils import run_and_get_code
@@ -41,6 +42,7 @@
     change_linear_weights_to_int4_woqtensors,
     change_linear_weights_to_int8_dqtensors,
     change_linear_weights_to_int8_woqtensors,
+    Float8DynamicActivationFloat8WeightConfig,
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
@@ -2077,6 +2079,29 @@ def forward(self, x):
         self.assertTrue(torch.ops.torchao.quantize_affine.default in targets)
         self.assertFalse(torch.ops.aten.narrow.default in targets)

+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    def test_export_float8(self):
+        class SimpleNetwork(torch.nn.Module):
+            def __init__(self):
+                super(SimpleNetwork, self).__init__()
+                self.linear = torch.nn.Linear(in_features=32, out_features=16, bias=False)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        model = SimpleNetwork().eval().cuda()
+        inp = torch.randn(2, 32).cuda()
+        config = Float8DynamicActivationFloat8WeightConfig()
+        quantize_(model, config)
+
+        ep = torch.export.export(model, (inp,))
+        print(ep)
+        FileCheck().check_count("torch.ops.torchao.choose_qparams_affine_float8.default", 1, exactly=True).run(
+            str(ep.graph)
+        )
+

 class TestUtils(unittest.TestCase):
     @parameterized.expand(
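
The new test quantizes a small linear model with Float8DynamicActivationFloat8WeightConfig, exports it, and uses FileCheck to assert that exactly one choose_qparams_affine_float8 node survives in the exported graph (presumably the runtime activation-scale computation) instead of being decomposed away. A standalone sketch of the same workflow outside unittest; it assumes a CUDA GPU with compute capability >= 8.9 (e.g. RTX 4090, L4, H100), and import paths may differ across torchao versions:

import torch
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_

model = torch.nn.Sequential(torch.nn.Linear(32, 16, bias=False)).eval().cuda()
quantize_(model, Float8DynamicActivationFloat8WeightConfig())

ep = torch.export.export(model, (torch.randn(2, 32).cuda(),))
targets = {n.target for n in ep.graph.nodes if n.op == "call_function"}
assert torch.ops.torchao.choose_qparams_affine_float8.default in targets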

torchao/quantization/quant_primitives.py
Lines changed: 3 additions & 2 deletions

@@ -2178,11 +2178,12 @@ def _dequantize_affine_floatx(
     return tensor


+@register_custom_op
 def _choose_qparams_affine_float8(
     tensor: torch.Tensor,
+    block_size: List[int],
     float8_dtype: torch.dtype = torch.float8_e4m3fn,
     scale_dtype: torch.dtype = torch.float32,
-    block_size: Optional[Tuple[int, ...]] = None,
 ) -> torch.Tensor:
     """
     Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity.
@@ -2195,7 +2196,7 @@ def _choose_qparams_affine_float8(
     """
     quant_max = torch.finfo(float8_dtype).max
     # only tensorwise scaling is supported for now:
-    if block_size is None:
+    if len(block_size) == 0:
         max_abs = tensor.abs().max()
         scale = max_abs / quant_max
     else:

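For the tensorwise branch shown above, the scale is simply max(|t|) / finfo(float8_dtype).max, i.e. the largest observed magnitude is mapped onto the largest representable float8 value. A self-contained sketch of that arithmetic (torchao's real implementation also applies scale_dtype and has a blockwise else-branch not shown in the hunk):

import torch

def tensorwise_float8_scale(
    t: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
    scale_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    # finfo(float8_e4m3fn).max == 448.0; finfo(float8_e5m2).max == 57344.0
    quant_max = torch.finfo(float8_dtype).max
    return (t.abs().max() / quant_max).to(scale_dtype)

x = torch.randn(32, 32)
s = tensorwise_float8_scale(x)
x_f8 = (x / s).to(torch.float8_e4m3fn)  # the quantize step, conceptually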