[moe training] Cast to mixed precision policy param dtype in fsdp_pre_all_gather hook #2455

Merged · 5 commits · Jul 2, 2025
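Context: with this change, ScaledGroupedMMTensor.fsdp_pre_all_gather receives the owning module's MixedPrecisionPolicy and casts its local shard to mp_policy.param_dtype before the all-gather, so the gather itself runs in the low-precision dtype. Below is a minimal sketch (not part of this PR) of how such a policy is typically attached with FSDP2's fully_shard; the nn.Linear is a placeholder for a real converted MoE model, and an initialized process group/device mesh is assumed.

# Sketch only: attach a MixedPrecisionPolicy so FSDP2 passes param_dtype to
# fsdp_pre_all_gather. Assumes torch.distributed is initialized (e.g. via torchrun);
# on older PyTorch builds fully_shard lives under torch.distributed._composable.fsdp.
import torch
from torch import nn
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

model = nn.Linear(1024, 1024)  # placeholder; a converted MoE model in practice
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # dtype used for all-gathered params / compute
    reduce_dtype=torch.float32,   # dtype used for gradient reduction
)
fully_shard(model, mp_policy=mp_policy)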
4 changes: 2 additions & 2 deletions torchao/prototype/moe_training/conversion_utils.py
@@ -84,7 +84,7 @@ def _swap_params(
f"Does not support a root nn.Parameter with children: {module}"
)
if not isinstance(module.data, ScaledGroupedMMTensor):
new_data = ScaledGroupedMMTensor(module.data, module.data.dtype)
new_data = ScaledGroupedMMTensor(module.data)
return nn.Parameter(new_data, requires_grad=module.requires_grad)
return module

@@ -110,7 +110,7 @@ def post_order_traversal(
for param_name, param in module.named_parameters(recurse=False):
if not isinstance(param.data, ScaledGroupedMMTensor):
new_param = nn.Parameter(
ScaledGroupedMMTensor(param.data, param.data.dtype),
ScaledGroupedMMTensor(param.data),
requires_grad=param.requires_grad,
)
setattr(module, param_name, new_param)
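The two hunks above drop the explicit dtype argument when wrapping parameter data; the wrapper now infers its dtype from the tensor it wraps (see tensor.py below). A simplified, generic sketch of this swap pattern follows; MyWrapperTensor is a hypothetical stand-in for illustration, not torchao's class, and its minimal __torch_dispatch__ exists only so nn.Parameter construction and basic ops work on the wrapper.

# Generic sketch of the post-order parameter swap performed above.
# MyWrapperTensor is hypothetical; ScaledGroupedMMTensor is the real class.
import torch
import torch.utils._pytree as pytree
from torch import nn


class MyWrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        # dtype is taken from the wrapped tensor; no separate dtype argument.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.size(),
            strides=tensor.stride(),
            dtype=tensor.dtype,
            device=tensor.device,
            requires_grad=tensor.requires_grad,
        )

    def __init__(self, tensor: torch.Tensor):
        self._data = tensor

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap wrappers, run the real op, rewrap tensor outputs.
        unwrap = lambda t: t._data if isinstance(t, MyWrapperTensor) else t
        args, kwargs = pytree.tree_map(unwrap, (args, kwargs or {}))
        out = func(*args, **kwargs)
        return pytree.tree_map_only(torch.Tensor, MyWrapperTensor, out)


def swap_params(module: nn.Module) -> None:
    # Post-order traversal: convert children before their parents.
    for child in module.children():
        swap_params(child)
    for name, param in module.named_parameters(recurse=False):
        if not isinstance(param.data, MyWrapperTensor):
            new_param = nn.Parameter(
                MyWrapperTensor(param.data), requires_grad=param.requires_grad
            )
            setattr(module, name, new_param)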
5 changes: 4 additions & 1 deletion torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Optional

import torch
@@ -18,6 +19,8 @@
_is_column_major,
)

logger: logging.Logger = logging.getLogger(__name__)


def _scaled_grouped_mm(
A: torch.Tensor,
@@ -36,8 +39,8 @@ def _scaled_grouped_mm(
and in column-major memory layout.
offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
use_triton_for_per_group_scales (bool): Whether to use custom triton kernels to compute per-group scales. Default is True.
"""
# logger.info("Using scaled_grouped_mm")
return _Float8GroupedMM.apply(
A,
B_t,
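For readers unfamiliar with grouped GEMMs: conceptually, _scaled_grouped_mm runs one matmul per expert over a contiguous slice of A's rows, with offs encoding the group boundaries along dim0 of A. A rough, unscaled reference loop is sketched below based on my reading of the docstring above; bounds stands in for the boundary information carried by offs, and the real op additionally applies per-group float8 scaling and an out_dtype.

# Reference semantics only, not the torchao kernel.
import torch


def grouped_mm_reference(
    A: torch.Tensor,                  # (total_tokens, K)
    B_t: torch.Tensor,                # (num_groups, K, N)
    bounds: list[tuple[int, int]],    # (start, end) row range of each group in A
) -> torch.Tensor:
    out = A.new_empty(A.shape[0], B_t.shape[-1])
    for group_idx, (start, end) in enumerate(bounds):
        # Each expert's weight multiplies only its own slice of tokens.
        out[start:end] = A[start:end] @ B_t[group_idx]
    return out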
66 changes: 41 additions & 25 deletions torchao/prototype/moe_training/tensor.py
@@ -9,13 +9,15 @@

import torch
import torch.utils._pytree as pytree
from torch import nn
from torch._prims_common import suggest_memory_format
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import MixedPrecisionPolicy

from torchao.prototype.moe_training import _scaled_grouped_mm

logger: logging.Logger = logging.getLogger(__name__)


_ops_to_preserve_subclass = {
torch.ops.aten.empty_like.default,
torch.ops.aten.new_zeros.default,
@@ -44,15 +46,14 @@ class ScaledGroupedMMTensor(torch.Tensor):
def __new__(
cls,
tensor: torch.Tensor,
dtype: torch.dtype,
):
return torch.Tensor._make_wrapper_subclass(
cls,
tensor.size(),
strides=tensor.stride(),
storage_offset=tensor.storage_offset(),
memory_format=suggest_memory_format(tensor),
dtype=dtype,
dtype=tensor.dtype,
layout=tensor.layout,
device=tensor.device,
pin_memory=tensor.is_pinned(),
@@ -62,14 +63,11 @@ def __new__(
def __init__(
self,
tensor: torch.Tensor,
dtype: torch.dtype,
):
self._data = tensor
self._dtype = dtype

@classmethod
def __torch_function__(cls, func, types, args, kwargs={}):
logger.info(f"{func.__name__}, args: {args}, kwargs: {kwargs}")
# override the grouped mm op to use the differentiable _scaled_grouped_mm
if func.__name__ == cls.grouped_mm_func_name:
# Use torchao scaled grouped mm with dynamic quant for
@@ -98,19 +96,10 @@ def __torch_function__(cls, func, types, args, kwargs={}):
def __torch_dispatch__(cls, func, types, args, kwargs={}):
# detach is special case
if func == torch.ops.aten.detach.default:
return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)

# unwrap args and kwargs
dtype: Optional[torch.dtype] = None

def unwrap(t):
nonlocal dtype
if dtype is None:
dtype = t._dtype
else:
assert t._dtype == dtype
return t._data
return ScaledGroupedMMTensor(args[0]._data)

# unwrap args/kwargs
unwrap = lambda x: x._data if isinstance(x, ScaledGroupedMMTensor) else x
args, kwargs = pytree.tree_map_only(
ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
)
@@ -125,25 +114,33 @@ def unwrap(t):
# wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
return pytree.tree_map_only(
torch.Tensor,
lambda x: ScaledGroupedMMTensor(x, dtype),
lambda x: ScaledGroupedMMTensor(x),
out,
)

def __repr__(self):
return f"ScaledGroupedMMTensor(data={self._data}, dtype={self._dtype})"
return f"ScaledGroupedMMTensor(data={self._data})"

def __tensor_flatten__(self):
return ["_data"], {"_dtype": self._dtype}
return ["_data"]

@staticmethod
def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
return ScaledGroupedMMTensor(
inner_tensors["_data"],
flatten_spec["_dtype"],
)

def fsdp_pre_all_gather(self, mesh):
all_gather_inputs = (self._data,)
# fsdp hooks based on https://github.com/pytorch/pytorch/blob/20e40492b046b9287726d3ec656117e4dc38f0e2/test/distributed/_composable/fsdp/test_fully_shard_extensions.py#L81
def fsdp_pre_all_gather(
self,
mesh: DeviceMesh,
outer_size: torch.Size,
outer_stride: tuple[int, ...],
module: nn.Module,
mp_policy: MixedPrecisionPolicy,
):
# cast to mixed precision dtype prior to all-gather
all_gather_inputs = (self._data.to(mp_policy.param_dtype),)
all_gather_metadata = ()
return all_gather_inputs, all_gather_metadata
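Casting to mp_policy.param_dtype here means the all-gather itself moves low-precision data, and fsdp_post_all_gather (below) receives data already in param_dtype. A back-of-the-envelope sketch of the communication saving for fp32 master weights gathered as bf16 follows; the shard size is hypothetical.

# Rough illustration only: payload of one all-gather per shard.
import torch

numel = 8_192 * 8_192  # hypothetical expert weight shard
fp32_bytes = numel * torch.float32.itemsize   # 4 bytes per element
bf16_bytes = numel * torch.bfloat16.itemsize  # 2 bytes per element
print(f"fp32 gather: {fp32_bytes / 2**20:.0f} MiB, bf16 gather: {bf16_bytes / 2**20:.0f} MiB")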

@@ -156,6 +153,25 @@ def fsdp_post_all_gather(
out: Optional[torch.Tensor] = None,
):
(data,) = all_gather_outputs
output = ScaledGroupedMMTensor(data, param_dtype)

# For training step 1+, out=unsharded param, so we need to copy data to `out`
# if `self._data` and `out` do not share the same storage.
# Otherwise, if they do share the same storage, we can just return directly.
if out is not None:
assert isinstance(out, ScaledGroupedMMTensor), f"{type(out)}"
if data.dtype == param_dtype:
assert (
data.untyped_storage().data_ptr()
== out._data.untyped_storage().data_ptr()
)
else:
assert out._data.dtype == param_dtype, (
f"{out._data.dtype} {param_dtype}"
)
out._data.copy_(data)
return

# For training step 0, out=None, so we need to return a new ScaledGroupedMMTensor.
Contributor: do we have a test for this?

Contributor (Author): We have a test for float MoE + FSDP training. We don't have a test verifying which code branch is followed in this fsdp_post_all_gather hook at training step 0 vs 1, but I think the FSDP test alone is sufficient. Let me know if you have other thoughts.

output = ScaledGroupedMMTensor(data)
inner_tensors = (data,)
return output, inner_tensors
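To make the hook contract concrete, here is a toy, single-process sketch (an assumption about usage, not FSDP internals) that calls fsdp_pre_all_gather directly and checks that the returned all-gather input is already in the policy's param_dtype. In a real run FSDP2 invokes this hook itself, and on step 1+ fsdp_post_all_gather copies the gathered data into the preallocated out as shown above.

# Toy check, not FSDP internals: the pre-all-gather hook casts to param_dtype.
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

shard = ScaledGroupedMMTensor(torch.randn(64, 128))   # fp32 local shard
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)

inputs, metadata = shard.fsdp_pre_all_gather(
    mesh=None,                      # unused by the hook body shown above
    outer_size=shard.size(),
    outer_stride=shard.stride(),
    module=None,                    # unused by the hook body shown above
    mp_policy=mp_policy,
)
assert inputs[0].dtype == torch.bfloat16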