Commit 41a7890

fix dtype bug and add logging
1 parent 02f061c commit 41a7890

File tree

2 files changed: +17 -26 lines changed

torchao/prototype/moe_training/scaled_grouped_mm.py
torchao/prototype/moe_training/tensor.py

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 1 addition & 2 deletions
@@ -19,7 +19,6 @@
     _is_column_major,
 )
 
-
 logger: logging.Logger = logging.getLogger(__name__)
 
 
@@ -41,7 +40,7 @@ def _scaled_grouped_mm(
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
     """
-    logger.info("Using scaled_grouped_mm")
+    logger.debug("Using scaled_grouped_mm")
     return _Float8GroupedMM.apply(
         A,
         B_t,
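The startup message now logs at DEBUG rather than INFO, so it only appears when debug logging is enabled for the torchao logger hierarchy. A minimal sketch of how a caller might surface it, using only the standard logging module (the logger name follows from logging.getLogger(__name__) in the file above):

import logging

# Show DEBUG-level records on the root handler.
logging.basicConfig(level=logging.DEBUG)
# scaled_grouped_mm.py creates its logger via logging.getLogger(__name__), so enabling
# the parent "torchao.prototype.moe_training" logger covers it as well.
logging.getLogger("torchao.prototype.moe_training").setLevel(logging.DEBUG)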

torchao/prototype/moe_training/tensor.py

Lines changed: 16 additions & 24 deletions
@@ -13,13 +13,11 @@
 from torch._prims_common import suggest_memory_format
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import MixedPrecisionPolicy
-from torch.autograd.grad_mode import _unsafe_preserve_version_counter
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
 
 logger: logging.Logger = logging.getLogger(__name__)
 
-
 _ops_to_preserve_subclass = {
     torch.ops.aten.empty_like.default,
     torch.ops.aten.new_zeros.default,
@@ -68,11 +66,14 @@ def __init__(
         tensor: torch.Tensor,
         dtype: torch.dtype,
     ):
-        self._data = tensor
+        self._data = tensor.to(dtype)
         self._dtype = dtype
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
+        logger.debug(
+            f"ScaledGroupedMMTensor func: {func.__name__}, args: {args}, kwargs: {kwargs}"
+        )
         # override the grouped mm op to use the differentiable _scaled_grouped_mm
         if func.__name__ == cls.grouped_mm_func_name:
             # Use torchao scaled grouped mm with dynamic quant for
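The dtype fix itself is the cast in __init__: the wrapped data is converted to the declared dtype at construction, so _data.dtype and _dtype can no longer disagree. A self-contained sketch of that invariant, using a hypothetical Wrapper class rather than the real ScaledGroupedMMTensor:

import torch

class Wrapper:
    """Hypothetical stand-in for ScaledGroupedMMTensor, for illustration only."""

    def __init__(self, tensor: torch.Tensor, dtype: torch.dtype):
        # Before the fix the tensor was stored as-is, so a float32 tensor could sit
        # under a bfloat16 _dtype; casting on construction keeps the two consistent.
        self._data = tensor.to(dtype)
        self._dtype = dtype

w = Wrapper(torch.randn(4, dtype=torch.float32), torch.bfloat16)
assert w._data.dtype == w._dtype == torch.bfloat16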
@@ -103,17 +104,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
         if func == torch.ops.aten.detach.default:
             return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)
 
-        # unwrap args and kwargs
-        dtype: Optional[torch.dtype] = None
-
-        def unwrap(t):
-            nonlocal dtype
-            if dtype is None:
-                dtype = t._dtype
-            else:
-                assert t._dtype == dtype
-            return t._data
-
+        # unwrap args/kwargs
+        unwrap = lambda x: x._data if isinstance(x, ScaledGroupedMMTensor) else x
         args, kwargs = pytree.tree_map_only(
             ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
         )
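The simplified unwrap relies on torch.utils._pytree.tree_map_only, which applies a function only to leaves of a given type anywhere in the nested args/kwargs. A minimal sketch of that pattern, with a hypothetical Wrapped class standing in for ScaledGroupedMMTensor:

import torch
import torch.utils._pytree as pytree

class Wrapped:
    """Hypothetical wrapper used only to illustrate the unwrap pattern."""

    def __init__(self, data: torch.Tensor):
        self._data = data

# Same shape as the new code: strip the wrapper, pass any other leaf through untouched.
unwrap = lambda x: x._data if isinstance(x, Wrapped) else x

args = (Wrapped(torch.ones(2)), 3.0)
kwargs = {"bias": Wrapped(torch.zeros(2)), "beta": 0.5}

# tree_map_only visits only the Wrapped leaves; floats and plain tensors are left alone.
args, kwargs = pytree.tree_map_only(Wrapped, unwrap, (args, kwargs))
print(type(args[0]), type(kwargs["bias"]))  # both are plain torch.Tensor now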
@@ -128,7 +120,7 @@ def unwrap(t):
         # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
         return pytree.tree_map_only(
             torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x, dtype),
+            lambda x: ScaledGroupedMMTensor(x, x.dtype),
             out,
         )
 
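With the per-call dtype variable gone, each output tensor is re-wrapped using its own dtype rather than a dtype captured from the inputs. A short sketch of that re-wrapping step (Wrapper again being a hypothetical stand-in for ScaledGroupedMMTensor):

import torch
import torch.utils._pytree as pytree

class Wrapper:
    """Hypothetical stand-in for ScaledGroupedMMTensor."""

    def __init__(self, tensor: torch.Tensor, dtype: torch.dtype):
        self._data = tensor.to(dtype)
        self._dtype = dtype

# Mixed-dtype outputs: after the change, each tensor keeps its own dtype when wrapped.
out = (torch.randn(2, dtype=torch.float32), torch.randn(2, dtype=torch.bfloat16))
wrapped = pytree.tree_map_only(torch.Tensor, lambda x: Wrapper(x, x.dtype), out)
print([w._dtype for w in wrapped])  # [torch.float32, torch.bfloat16]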

@@ -154,9 +146,11 @@ def fsdp_pre_all_gather(
         module: nn.Module,
         mp_policy: MixedPrecisionPolicy,
     ):
-        all_gather_inputs = (self._data.to(mp_policy.param_dtype),)
+        all_gather_inputs = (self._data,)
         all_gather_metadata = ()
-        logger.debug(f"fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, param_dtype: {mp_policy.param_dtype}")
+        logger.debug(
+            f"ScaledGroupedMMTensor fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, param_dtype: {mp_policy.param_dtype}"
+        )
         return all_gather_inputs, all_gather_metadata
 
     def fsdp_post_all_gather(
@@ -168,15 +162,13 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ):
         (data,) = all_gather_outputs
-        logger.debug(f"fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}")
+        logger.debug(
+            f"ScaledGroupedMMTensor fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}"
+        )
 
         if out is not None:
-            #with _unsafe_preserve_version_counter(out):
-            with torch.no_grad():
-                out.copy_(data)
             return
 
-        upcast_data = data.to(param_dtype)
-        output = ScaledGroupedMMTensor(upcast_data, param_dtype)
-        inner_tensors = (upcast_data,)
+        output = ScaledGroupedMMTensor(data, param_dtype)
+        inner_tensors = (data,)
         return output, inner_tensors
