
Commit e471ca5

mvpatel2000 authored and bmosaicml committed
Speed monitor refactor (mosaicml#1987)
* add speed monitor refactor
* fix docs
* fix tests
* fix remove 1
* extend test
* format
* respond to comments
* restore caching
* add deepcopy
* add comment
1 parent e8fb131 commit e471ca5

1 file changed: +29 −8 lines changed

1 file changed

+29
-8
lines changed

composer/optim/decoupled_weight_decay.py

Lines changed: 29 additions & 8 deletions
@@ -225,9 +225,9 @@ def __init__(self,
 
     @staticmethod
     def adamw(params: List[torch.Tensor], grads: List[torch.Tensor], exp_avgs: List[torch.Tensor],
-              exp_avg_sqs: List[torch.Tensor], max_exp_avg_sqs: List[torch.Tensor], state_steps: List[int], *,
-              amsgrad: bool, beta1: float, beta2: float, lr: float, initial_lr: float, weight_decay: float,
-              eps: float) -> None:
+              exp_avg_sqs: List[torch.Tensor], max_exp_avg_sqs: List[torch.Tensor], state_steps: List[int],
+              masks_on: List[bool], *, amsgrad: bool, beta1: float, beta2: float, lr: float, initial_lr: float,
+              weight_decay: float, eps: float) -> None:
         r"""Functional API that performs AdamW algorithm computation with decoupled weight decay.
 
         Args:
@@ -250,6 +250,7 @@ def adamw(params: List[torch.Tensor], grads: List[torch.Tensor], exp_avgs: List[
             exp_avg = exp_avgs[i]
             exp_avg_sq = exp_avg_sqs[i]
             step = state_steps[i]
+            mask_on = masks_on[i]
 
             # Perform stepweight decay
             if weight_decay != 0:
@@ -259,20 +260,32 @@ def adamw(params: List[torch.Tensor], grads: List[torch.Tensor], exp_avgs: List[
             bias_correction1 = 1 - beta1**step
             bias_correction2 = 1 - beta2**step
 
+            # mask out any params from the moment that point in the opposite direction of
+            # the grad
+            if mask_on:
+                update = exp_avg.sign().mul_(grad.sign()).sign_().clamp_(0)
+                update.mul_(exp_avg).mul_(beta1).add_(grad, alpha=1 - beta1)
+            else:
+                update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)
+
             # Decay the first and second moment running average coefficient
-            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
             exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
             if amsgrad:
                 # Maintains the maximum of all 2nd moment running avg. till now
                 torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
                 # Use the max. for normalizing running avg. of gradient
-                denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps)
+                update.div_((max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps))
             else:
-                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
+                update.div_((exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps))
 
             step_size = lr / bias_correction1
 
-            param.addcdiv_(exp_avg, denom, value=-step_size)
+            param.add_(update, alpha=-step_size)
+
+            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+
+    def turn_on_masking(self, param):
+        self.state[param]['mask'] = True
 
     @torch.no_grad()
     def step(self, closure=None):
@@ -292,6 +305,7 @@ def step(self, closure=None):
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
+           masks_on = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group['amsgrad']
@@ -326,6 +340,7 @@
 
                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
+               masks_on.append('mask' in state and state['mask'])
                if amsgrad:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])
 
@@ -340,6 +355,7 @@
                       exp_avg_sqs,
                       max_exp_avg_sqs,
                       state_steps,
+                      masks_on,
                       amsgrad=amsgrad,
                       beta1=beta1,
                       beta2=beta2,
@@ -414,7 +430,12 @@ def report_per_parameter_metrics(self, param: torch.Tensor, name: str, optimizer
            bias_correction2 = 1 - beta2**step
            denom = (param_optim_state['exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)).add_(eps)
            step_size = lr / bias_correction1
-           step_tensor = step_size * param_optim_state['exp_avg'].div(denom)
+           if 'mask' in param_optim_state and param_optim_state['mask']:
+               step_tensor = param_optim_state['exp_avg'].sign().mul_(param.grad.sign()).sign_().clamp_(0)
+               step_tensor.mul_(param_optim_state['exp_avg']).mul_(beta1).add_(param.grad, alpha=1 - beta1)
+               step_tensor = step_size * step_tensor.div(denom)
+           else:
+               step_tensor = step_size * param_optim_state['exp_avg'].div(denom)
            decay_factor = (lr / initial_lr) if initial_lr else 1.0
            step_tensor.add_(param, alpha=-weight_decay * decay_factor)
            for metric in self.metric_functions:

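The new turn_on_masking hook simply sets a per-parameter 'mask' flag in the optimizer state, which step collects into masks_on and forwards to the functional adamw. Assuming this commit lands as shown, enabling masking for a subset of parameters might look like the sketch below; the model, layer choice, and hyperparameters are made up.

# Hypothetical usage sketch of the turn_on_masking hook from the diff above.
import torch
from composer.optim import DecoupledAdamW

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
optimizer = DecoupledAdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.95), weight_decay=1e-5)

# One ordinary step so the optimizer state for each parameter exists.
model(torch.randn(4, 8)).sum().backward()
optimizer.step()
optimizer.zero_grad()

# Enable sign-agreement masking only for the last layer; every other
# parameter keeps the standard decoupled-AdamW update.
for p in model[-1].parameters():
    optimizer.turn_on_masking(p)

# Subsequent steps use the masked update for those parameters.
model(torch.randn(4, 8)).sum().backward()
optimizer.step()
optimizer.zero_grad()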