from torch._prims_common import suggest_memory_format
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import MixedPrecisionPolicy
-from torch.autograd.grad_mode import _unsafe_preserve_version_counter

from torchao.prototype.moe_training import _scaled_grouped_mm

logger: logging.Logger = logging.getLogger(__name__)

-
_ops_to_preserve_subclass = {
    torch.ops.aten.empty_like.default,
    torch.ops.aten.new_zeros.default,
@@ -68,11 +66,14 @@ def __init__(
        tensor: torch.Tensor,
        dtype: torch.dtype,
    ):
-        self._data = tensor
+        self._data = tensor.to(dtype)
        self._dtype = dtype

    @classmethod
    def __torch_function__(cls, func, types, args, kwargs={}):
+        logger.debug(
+            f"ScaledGroupedMMTensor func: {func.__name__}, args: {args}, kwargs: {kwargs}"
+        )
        # override the grouped mm op to use the differentiable _scaled_grouped_mm
        if func.__name__ == cls.grouped_mm_func_name:
            # Use torchao scaled grouped mm with dynamic quant for
@@ -103,17 +104,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
        if func == torch.ops.aten.detach.default:
            return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)

-        # unwrap args and kwargs
-        dtype: Optional[torch.dtype] = None
-
-        def unwrap(t):
-            nonlocal dtype
-            if dtype is None:
-                dtype = t._dtype
-            else:
-                assert t._dtype == dtype
-            return t._data
-
+        # unwrap args/kwargs
+        unwrap = lambda x: x._data if isinstance(x, ScaledGroupedMMTensor) else x
        args, kwargs = pytree.tree_map_only(
            ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
        )
@@ -128,7 +120,7 @@ def unwrap(t):
        # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
        return pytree.tree_map_only(
            torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x, dtype),
+            lambda x: ScaledGroupedMMTensor(x, x.dtype),
            out,
        )

@@ -154,9 +146,11 @@ def fsdp_pre_all_gather(
        module: nn.Module,
        mp_policy: MixedPrecisionPolicy,
    ):
-        all_gather_inputs = (self._data.to(mp_policy.param_dtype),)
+        all_gather_inputs = (self._data,)
        all_gather_metadata = ()
-        logger.debug(f"fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, param_dtype: {mp_policy.param_dtype}")
+        logger.debug(
+            f"ScaledGroupedMMTensor fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, param_dtype: {mp_policy.param_dtype}"
+        )
        return all_gather_inputs, all_gather_metadata

    def fsdp_post_all_gather(
@@ -168,15 +162,13 @@ def fsdp_post_all_gather(
        out: Optional[torch.Tensor] = None,
    ):
        (data,) = all_gather_outputs
-        logger.debug(f"fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}")
+        logger.debug(
+            f"ScaledGroupedMMTensor fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}"
+        )

        if out is not None:
-            #with _unsafe_preserve_version_counter(out):
-            with torch.no_grad():
-                out.copy_(data)
            return

-        upcast_data = data.to(param_dtype)
-        output = ScaledGroupedMMTensor(upcast_data, param_dtype)
-        inner_tensors = (upcast_data,)
+        output = ScaledGroupedMMTensor(data, param_dtype)
+        inner_tensors = (data,)
        return output, inner_tensors
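
For reference, below is a minimal standalone sketch of the unwrap/re-wrap pattern that the __torch_dispatch__ change above adopts: pull the inner tensor out of every subclass instance in args/kwargs with pytree.tree_map_only, run the op on plain tensors, then wrap tensor outputs back into the subclass. This is not the torchao implementation; WrapperTensor and the toy usage are hypothetical, for illustration only.

import torch
from torch.utils import _pytree as pytree


class WrapperTensor(torch.Tensor):
    # Hypothetical wrapper subclass; stands in for ScaledGroupedMMTensor here.
    @staticmethod
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data: torch.Tensor):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # unwrap: replace every WrapperTensor in args/kwargs with its inner tensor
        unwrap = lambda t: t._data if isinstance(t, WrapperTensor) else t
        args, kwargs = pytree.tree_map_only(WrapperTensor, unwrap, (args, kwargs or {}))
        out = func(*args, **kwargs)
        # re-wrap: put plain tensor outputs back into the subclass
        return pytree.tree_map_only(torch.Tensor, lambda t: WrapperTensor(t), out)


x = WrapperTensor(torch.randn(4, 4))
y = x + 1  # routes through __torch_dispatch__ on plain tensors
print(type(y).__name__, y._data.shape)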