@@ -225,9 +225,9 @@ def __init__(self,
 
     @staticmethod
     def adamw(params: List[torch.Tensor], grads: List[torch.Tensor], exp_avgs: List[torch.Tensor],
-              exp_avg_sqs: List[torch.Tensor], max_exp_avg_sqs: List[torch.Tensor], state_steps: List[int], *,
-              amsgrad: bool, beta1: float, beta2: float, lr: float, initial_lr: float, weight_decay: float,
-              eps: float) -> None:
+              exp_avg_sqs: List[torch.Tensor], max_exp_avg_sqs: List[torch.Tensor], state_steps: List[int],
+              masks_on: List[bool], *, amsgrad: bool, beta1: float, beta2: float, lr: float, initial_lr: float,
+              weight_decay: float, eps: float) -> None:
         r"""Functional API that performs AdamW algorithm computation with decoupled weight decay.
 
         Args:
@@ -250,6 +250,7 @@ def adamw(params: List[torch.Tensor], grads: List[torch.Tensor], exp_avgs: List[
             exp_avg = exp_avgs[i]
             exp_avg_sq = exp_avg_sqs[i]
             step = state_steps[i]
+            mask_on = masks_on[i]
 
             # Perform stepweight decay
             if weight_decay != 0:
@@ -259,20 +260,32 @@ def adamw(params: List[torch.Tensor], grads: List[torch.Tensor], exp_avgs: List[
             bias_correction1 = 1 - beta1**step
             bias_correction2 = 1 - beta2**step
 
+            # mask out any params from the moment that point in the opposite direction of
+            # the grad
+            if mask_on:
+                update = exp_avg.sign().mul_(grad.sign()).sign_().clamp_(0)
+                update.mul_(exp_avg).mul_(beta1).add_(grad, alpha=1 - beta1)
+            else:
+                update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)
+
             # Decay the first and second moment running average coefficient
-            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
             exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
             if amsgrad:
                 # Maintains the maximum of all 2nd moment running avg. till now
                 torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
                 # Use the max. for normalizing running avg. of gradient
-                denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps)
+                update.div_((max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps))
             else:
-                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
+                update.div_((exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps))
 
             step_size = lr / bias_correction1
 
-            param.addcdiv_(exp_avg, denom, value=-step_size)
+            param.add_(update, alpha=-step_size)
+
+            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+
+    def turn_on_masking(self, param):
+        self.state[param]['mask'] = True
 
     @torch.no_grad()
     def step(self, closure=None):
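For readers skimming the hunk above: the new `if mask_on:` branch builds a 0/1 mask from sign agreement between the first moment and the current gradient, then forms the usual exponential-moving-average numerator from the masked moment, so coordinates where the moment opposes the gradient fall back to `(1 - beta1) * grad`. A minimal, self-contained sketch of just that computation (tensor values are purely illustrative, not taken from this commit):

import torch

# Illustrative values only; out-of-place ops are used here so the inputs are
# left untouched (the diff uses in-place ops on the freshly created sign tensor).
exp_avg = torch.tensor([0.5, -0.2, 0.1, 0.0])
grad = torch.tensor([1.0, 1.0, -1.0, 1.0])
beta1 = 0.9

# 1 where exp_avg and grad agree in sign, 0 where they disagree or either is zero
mask = exp_avg.sign().mul(grad.sign()).sign().clamp(0)
print(mask)  # tensor([1., 0., 0., 0.])

# masked Adam-style numerator: beta1 * (mask * exp_avg) + (1 - beta1) * grad
update = mask.mul(exp_avg).mul(beta1).add(grad, alpha=1 - beta1)
print(update)  # tensor([ 0.5500,  0.1000, -0.1000,  0.1000])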
@@ -292,6 +305,7 @@ def step(self, closure=None):
             grads = []
             exp_avgs = []
             exp_avg_sqs = []
+            masks_on = []
             max_exp_avg_sqs = []
             state_steps = []
             amsgrad = group['amsgrad']
@@ -326,6 +340,7 @@ def step(self, closure=None):
 
                 exp_avgs.append(state['exp_avg'])
                 exp_avg_sqs.append(state['exp_avg_sq'])
+                masks_on.append('mask' in state and state['mask'])
                 if amsgrad:
                     max_exp_avg_sqs.append(state['max_exp_avg_sq'])
 
@@ -340,6 +355,7 @@ def step(self, closure=None):
                        exp_avg_sqs,
                        max_exp_avg_sqs,
                        state_steps,
+                       masks_on,
                        amsgrad=amsgrad,
                        beta1=beta1,
                        beta2=beta2,
@@ -414,7 +430,12 @@ def report_per_parameter_metrics(self, param: torch.Tensor, name: str, optimizer
             bias_correction2 = 1 - beta2**step
             denom = (param_optim_state['exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)).add_(eps)
             step_size = lr / bias_correction1
-            step_tensor = step_size * param_optim_state['exp_avg'].div(denom)
+            if 'mask' in param_optim_state and param_optim_state['mask']:
+                step_tensor = param_optim_state['exp_avg'].sign().mul_(param.grad.sign()).sign_().clamp_(0)
+                step_tensor.mul_(param_optim_state['exp_avg']).mul_(beta1).add_(param.grad, alpha=1 - beta1)
+                step_tensor = step_size * step_tensor.div(denom)
+            else:
+                step_tensor = step_size * param_optim_state['exp_avg'].div(denom)
             decay_factor = (lr / initial_lr) if initial_lr else 1.0
             step_tensor.add_(param, alpha=-weight_decay * decay_factor)
             for metric in self.metric_functions:
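The diff never shows how a user switches the new behavior on. A hedged usage sketch follows; the import path and class name `DecoupledAdamW` are assumptions (the class being patched is never named in these hunks), and masking stays off for every parameter until `turn_on_masking` is called for it, because `masks_on` falls back to False whenever 'mask' is absent from that parameter's state.

import torch
# Assumption: the commit patches Composer's DecoupledAdamW; adjust the import
# if the patched optimizer lives elsewhere.
from composer.optim import DecoupledAdamW

model = torch.nn.Linear(4, 2)
optimizer = DecoupledAdamW(model.parameters(), lr=1e-3)

# Take one ordinary step first so each parameter's state dict is initialized
# before turn_on_masking writes into self.state[param].
model(torch.randn(8, 4)).sum().backward()
optimizer.step()
optimizer.zero_grad()

# From now on, updates for this parameter use the sign-agreement mask.
optimizer.turn_on_masking(model.weight)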