
Commit dc0803c

Update
[ghstack-poisoned]
1 parent dd36237 commit dc0803c

4 files changed: +81 -38 lines changed


test/prototype/mx_formats/test_mx_dtensor.py

Lines changed: 15 additions & 1 deletion
@@ -69,12 +69,25 @@ def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=4):
 
 
 def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=128):
+    config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
+    config.block_size = 32
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, config, size, compile=False, allgather_in_lowp=False
+    )
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, config, size, compile=True, allgather_in_lowp=False
+    )
+
+
+def _test_mxfp8_mlp_tensor_parallelism_dim1_triton(mesh: DeviceMesh, size=128):
     config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
     config.block_size = 32
     config.use_fp8_dim1_cast_triton_kernel = True
     _test_lowp_mlp_tensor_parallelism_base(
         mesh, config, size, compile=False, allgather_in_lowp=False
     )
+    # TODO(future PR): enable compile here, currently seeing
+    # https://www.internalfb.com/phabricator/paste/view/P1851219639
     # _test_lowp_mlp_tensor_parallelism_base(
     #     mesh, config, size, compile=True, allgather_in_lowp=False
     # )
@@ -83,8 +96,9 @@ def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=128):
 if __name__ == "__main__":
     device_mesh = setup_distributed()
     tests = [
-        # _test_dtensor_cast_to_mxfp8,
+        _test_dtensor_cast_to_mxfp8,
         _test_mxfp8_mlp_tensor_parallelism,
+        _test_mxfp8_mlp_tensor_parallelism_dim1_triton,
     ]
 
     for test in tqdm(tests, desc="Running tests"):
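
Note: these are DTensor tests, so they must be launched with one process per device (for example via torchrun). As a rough, illustrative sketch of the harness they rely on, setup_distributed() typically builds a 1-D CUDA DeviceMesh from the launcher-provided world size; the actual helper in the test file may differ from this approximation:

# Illustrative sketch only; the real setup_distributed() in this test file
# may differ. Assumes the script is launched with torchrun.
import os

import torch
from torch.distributed.device_mesh import init_device_mesh


def setup_distributed():
    world_size = int(os.environ["WORLD_SIZE"])
    device_mesh = init_device_mesh("cuda", (world_size,))
    # use the same seed on every rank so all processes build identical inputs
    torch.manual_seed(1)
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    return device_mesh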

torchao/prototype/mx_formats/kernels.py

Lines changed: 15 additions & 4 deletions
@@ -1363,10 +1363,21 @@ def triton_to_mxfp8_dim1(
             output_col_major.t(),
             col_scale.view(torch.float8_e8m0fnu),
         )
-
-    print('ASDFASDFASDF')
-    from torchao import triton_to_mxfp8_dim1
-    print(triton_to_mxfp8_dim1)
+
+    # print(torch.ops.torchao.triton_to_mxfp8_dim1.default)
+
+    from torch.distributed.tensor import Replicate, Shard
+    from torch.distributed.tensor.experimental import register_sharding
+
+    @register_sharding(torch.ops.torchao.triton_to_mxfp8_dim1.default)
+    def custom_triton_to_mxfp8_dim1_sharding(x, inner_block_size=32):
+        replicate = ([Replicate(), Replicate()], [Replicate(), None])
+        # Note that the data is returned transposed, which is why
+        # we flip the sharding dim below
+        shard_dim0 = ([Shard(1), Shard(1)], [Shard(0), None])
+        shard_dim1 = ([Shard(0), Shard(0)], [Shard(1), None])
+        acceptable_shardings = [replicate, shard_dim0, shard_dim1]
+        return acceptable_shardings
 
     def triton_to_mxfp8_dim1_reference(
         x_hp: torch.Tensor, block_size
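
For context, register_sharding is an experimental DTensor API that tells the sharding propagator which output/input placement combinations a custom op supports, so that triton_to_mxfp8_dim1 can accept DTensor inputs in the tensor-parallel path. A minimal usage sketch under assumed conditions (torchrun launch, a 1-D CUDA mesh, the registration above already imported); the shapes are illustrative:

# Minimal sketch: call the custom op on a DTensor once the sharding rule above
# is registered. Assumes a torchrun launch and that torchao's triton kernels
# are available on this GPU.
import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
x_hp = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16)
x_dt = distribute_tensor(x_hp, mesh, [Shard(0)])

# Per the shard_dim0 rule above, a dim-0-sharded input maps to outputs sharded
# on dim 1 (the kernel returns the data transposed).
data_dt, scale_dt = torch.ops.torchao.triton_to_mxfp8_dim1.default(x_dt, 32)
print(data_dt.placements, scale_dt.placements)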

torchao/prototype/mx_formats/mx_linear.py

Lines changed: 49 additions & 31 deletions
@@ -12,6 +12,7 @@
 
 import torch
 import torch.nn.functional as F
+from torch.distributed._tensor import DTensor
 
 from torchao.prototype.mx_formats.config import (
     MXGemmKernelChoice,
@@ -25,6 +26,46 @@
 )
 
 
+def _triton_to_mxfp8_dim1_wrapper(
+    a, block_size, elem_dtype, hp_dtype, gemm_kernel_choice
+):
+    a_data, a_scale = triton_to_mxfp8_dim1(a, block_size)
+    if isinstance(a_data, DTensor):
+        assert isinstance(a_scale, DTensor)
+        a_data_local = a_data.to_local()
+        a_scale_local = a_scale.to_local()
+        inner = MXTensor(
+            a_scale_local,
+            a_data_local.t(),
+            elem_dtype,
+            block_size,
+            hp_dtype,
+            False,
+            gemm_kernel_choice,
+            False,
+        )
+        mx_tensor = DTensor.from_local(
+            inner,
+            a_data.device_mesh,
+            a_data.placements,
+            run_check=False,
+            shape=a_data.t().size(),
+            stride=a_data.t().stride(),
+        )
+    else:
+        mx_tensor = MXTensor(
+            a_scale,
+            a_data.t(),
+            elem_dtype,
+            block_size,
+            hp_dtype,
+            False,
+            gemm_kernel_choice,
+            False,
+        )
+    return mx_tensor
+
+
 @torch._dynamo.allow_in_graph
 class mx_mm(torch.autograd.Function):
     # There are three gemms in a forward + backward of a Linear layer:
@@ -95,20 +136,9 @@ def backward(ctx, grad_output_hp: torch.Tensor):
         )
 
         if use_fp8_dim1_cast_triton_kernel:
-            weight_mx_dim1_data, weight_mx_dim1_scale = triton_to_mxfp8_dim1(
-                weight_hp, block_size
+            weight_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
+                weight_hp, block_size, w_elem_dtype, weight_hp.dtype, gemm_kernel_choice
             )
-            weight_mx_dim1 = MXTensor(
-                weight_mx_dim1_scale.reshape(-1),
-                weight_mx_dim1_data.t(),
-                w_elem_dtype,
-                block_size,
-                weight_hp.dtype,
-                False,
-                gemm_kernel_choice,
-                False,
-            )
-
         else:
             weight_hp_t_c = weight_hp.t().contiguous()
             weight_mx_dim1 = MXTensor.to_mx(
@@ -124,18 +154,12 @@ def backward(ctx, grad_output_hp: torch.Tensor):
 
         # input_t @ grad_output = grad_weight
         if use_fp8_dim1_cast_triton_kernel:
-            grad_output_mx_dim1_data, grad_output_mx_dim1_scale = triton_to_mxfp8_dim1(
-                grad_output_hp_r, block_size
-            )
-            grad_output_mx_dim1 = MXTensor(
-                grad_output_mx_dim1_scale.reshape(-1),
-                grad_output_mx_dim1_data.t(),
-                grad_elem_dtype,
+            grad_output_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
+                grad_output_hp_r,
                 block_size,
+                grad_elem_dtype,
                 grad_output_hp_r.dtype,
-                False,
                 gemm_kernel_choice,
-                False,
             )
         else:
             grad_output_mx_dim1 = MXTensor.to_mx(
@@ -146,18 +170,12 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             )
 
         if use_fp8_dim1_cast_triton_kernel:
-            input_t_mx_dim0_tmp_data, input_t_mx_dim0_tmp_scale = triton_to_mxfp8_dim1(
-                input_hp_r, block_size
-            )
-            input_t_mx_dim0_tmp = MXTensor(
-                input_t_mx_dim0_tmp_scale.reshape(-1),
-                input_t_mx_dim0_tmp_data.t(),
-                in_elem_dtype,
+            input_t_mx_dim0_tmp = _triton_to_mxfp8_dim1_wrapper(
+                input_hp_r,
                 block_size,
+                in_elem_dtype,
                 input_hp_r.dtype,
-                False,
                 gemm_kernel_choice,
-                False,
             )
             input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
         else:
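
_triton_to_mxfp8_dim1_wrapper follows the standard DTensor pattern for kernels that only understand plain tensors: unwrap to the local shard with to_local(), run the local computation, then re-wrap with DTensor.from_local while preserving the original device mesh and placements (here also fixing up shape and stride because the data comes back transposed). A stripped-down sketch of that pattern under assumed conditions (torchrun launch, 1-D CUDA mesh); my_local_op is a placeholder, not a torchao function:

# Stripped-down sketch of the unwrap / compute-locally / re-wrap pattern used
# by _triton_to_mxfp8_dim1_wrapper. `my_local_op` is a placeholder for a kernel
# that only accepts plain torch.Tensor inputs.
import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard, distribute_tensor


def my_local_op(t: torch.Tensor) -> torch.Tensor:
    return t * 2  # stand-in for the local cast/quantization kernel


mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
x = distribute_tensor(torch.randn(128, 64, device="cuda"), mesh, [Shard(0)])

y_local = my_local_op(x.to_local())  # run the kernel on the local shard only
y = DTensor.from_local(  # re-wrap, keeping the original mesh and placements
    y_local,
    x.device_mesh,
    x.placements,
    run_check=False,
)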

torchao/testing/training/dtensor_utils.py

Lines changed: 2 additions & 2 deletions
@@ -151,8 +151,8 @@ def _test_lowp_mlp_tensor_parallelism_base(
         sp_model = torch.compile(sp_model)
         sp_model2 = torch.compile(sp_model2)
 
-    x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
-    go_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
+    x_fp32 = torch.rand(1, size * 2, size, device=device, requires_grad=False)
+    go_fp32 = torch.rand(1, size * 2, size, device=device, requires_grad=False)
     x_fp32_tp_input = x_fp32.clone()
     go_fp32_tp = go_fp32.clone()
     x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
