|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved |
| 3 | + |
| 4 | + |
| 5 | +import torch |
| 6 | +from pytext.optimizer.optimizers import Optimizer |
| 7 | +from torch.optim import Optimizer as PT_Optimizer |
| 8 | + |
| 9 | + |
| 10 | +class Lamb(Optimizer, PT_Optimizer): |
| 11 | + r"""Implements Lamb algorithm. |
| 12 | + THIS WAS DIRECTLY COPIED OVER FROM pytorch/contrib |
| 13 | + https://github.com/cybertronai/pytorch-lamb |
| 14 | + It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. |
| 15 | +
|
| 16 | + Arguments: |
| 17 | + params (iterable): iterable of parameters to optimize or dicts defining |
| 18 | + parameter groups |
| 19 | + lr (float, optional): learning rate (default: 1e-3) |
| 20 | + betas (Tuple[float, float], optional): coefficients used for computing |
| 21 | + running averages of gradient and its square (default: (0.9, 0.999)) |
| 22 | + eps (float, optional): term added to the denominator to improve |
| 23 | + numerical stability (default: 1e-8) |
| 24 | + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) |
| 25 | + .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: |
| 26 | + https://arxiv.org/abs/1904.00962 |
| 27 | + """ |
| 28 | + |
| 29 | + class Config(Optimizer.Config): |
| 30 | + lr: float = 0.001 |
| 31 | + weight_decay: float = 0.00001 |
| 32 | + eps: float = 1e-8 |
| 33 | + |
| 34 | + @classmethod |
| 35 | + def from_config(cls, config: Config, model: torch.nn.Module): |
| 36 | + return cls( |
| 37 | + model.parameters(), |
| 38 | + lr=config.lr, |
| 39 | + weight_decay=config.weight_decay, |
| 40 | + eps=config.eps, |
| 41 | + ) |
| 42 | + |
| 43 | + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0): |
| 44 | + if not 0.0 <= lr: |
| 45 | + raise ValueError("Invalid learning rate: {}".format(lr)) |
| 46 | + if not 0.0 <= eps: |
| 47 | + raise ValueError("Invalid epsilon value: {}".format(eps)) |
| 48 | + if not 0.0 <= betas[0] < 1.0: |
| 49 | + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) |
| 50 | + if not 0.0 <= betas[1] < 1.0: |
| 51 | + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) |
| 52 | + PT_Optimizer.__init__( |
| 53 | + self, |
| 54 | + params, |
| 55 | + {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay}, |
| 56 | + ) |
| 57 | + |
| 58 | + def step(self, closure=None): |
| 59 | + """Performs a single optimization step. |
| 60 | +
|
| 61 | + Arguments: |
| 62 | + closure (callable, optional): A closure that reevaluates the model |
| 63 | + and returns the loss. |
| 64 | + """ |
| 65 | + loss = None |
| 66 | + if closure is not None: |
| 67 | + loss = closure() |
| 68 | + |
| 69 | + for group in self.param_groups: |
| 70 | + for p in group["params"]: |
| 71 | + if p.grad is None: |
| 72 | + continue |
| 73 | + grad = p.grad.data |
| 74 | + if grad.is_sparse: |
| 75 | + raise RuntimeError( |
| 76 | + "Lamb does not support sparse gradients, consider SparseAdam instad." |
| 77 | + ) |
| 78 | + |
| 79 | + state = self.state[p] |
| 80 | + |
| 81 | + # State initialization |
| 82 | + if len(state) == 0: |
| 83 | + state["step"] = 0 |
| 84 | + # Exponential moving average of gradient values |
| 85 | + state["exp_avg"] = torch.zeros_like(p.data) |
| 86 | + # Exponential moving average of squared gradient values |
| 87 | + state["exp_avg_sq"] = torch.zeros_like(p.data) |
| 88 | + |
| 89 | + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] |
| 90 | + beta1, beta2 = group["betas"] |
| 91 | + |
| 92 | + state["step"] += 1 |
| 93 | + |
| 94 | + # Decay the first and second moment running average coefficient |
| 95 | + # m_t |
| 96 | + exp_avg.mul_(beta1).add_(1 - beta1, grad) |
| 97 | + # v_t |
| 98 | + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) |
| 99 | + |
| 100 | + # Paper v3 does not use debiasing. |
| 101 | + # bias_correction1 = 1 - beta1 ** state['step'] |
| 102 | + # bias_correction2 = 1 - beta2 ** state['step'] |
| 103 | + # Apply bias to lr to avoid broadcast. |
| 104 | + step_size = group["lr"] |
| 105 | + # * math.sqrt(bias_correction2) / bias_correction1 |
| 106 | + |
| 107 | + weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) |
| 108 | + |
| 109 | + adam_step = exp_avg / exp_avg_sq.sqrt().add(group["eps"]) |
| 110 | + if group["weight_decay"] != 0: |
| 111 | + adam_step.add_(group["weight_decay"], p.data) |
| 112 | + |
| 113 | + adam_norm = adam_step.pow(2).sum().sqrt() |
| 114 | + if weight_norm == 0 or adam_norm == 0: |
| 115 | + trust_ratio = 1 |
| 116 | + else: |
| 117 | + trust_ratio = weight_norm / adam_norm |
| 118 | + state["weight_norm"] = weight_norm |
| 119 | + state["adam_norm"] = adam_norm |
| 120 | + state["trust_ratio"] = trust_ratio |
| 121 | + p.data.add_(-step_size * trust_ratio, adam_step) |
| 122 | + |
| 123 | + return loss |
0 commit comments