|
| 1 | +import math |
| 2 | + |
| 3 | +import torch |
| 4 | +from torch.optim.optimizer import Optimizer |
| 5 | + |
| 6 | +from pytorch_optimizer.types import ( |
| 7 | + BETAS, |
| 8 | + CLOSURE, |
| 9 | + DEFAULT_PARAMETERS, |
| 10 | + LOSS, |
| 11 | + PARAMS, |
| 12 | + STATE, |
| 13 | +) |
| 14 | + |
| 15 | + |
| 16 | +class AdaBelief(Optimizer): |
| 17 | + """ |
| 18 | + Reference : https://github.com/juntang-zhuang/Adabelief-Optimizer |
| 19 | + Example : |
| 20 | + from pytorch_optimizer import AdaBelief |
| 21 | + ... |
| 22 | + model = YourModel() |
| 23 | + optimizer = AdaBelief(model.parameters()) |
| 24 | + ... |
| 25 | + for input, output in data: |
| 26 | + optimizer.zero_grad() |
| 27 | + loss = loss_function(output, model(input)) |
| 28 | + loss.backward() |
| 29 | + optimizer.step() |
| 30 | + """ |
| 31 | + |
| 32 | + def __init__( |
| 33 | + self, |
| 34 | + params: PARAMS, |
| 35 | + lr: float = 1e-3, |
| 36 | + betas: BETAS = (0.9, 0.999), |
| 37 | + eps: float = 1e-16, |
| 38 | + weight_decay: float = 0.0, |
| 39 | + n_sma_threshold: int = 5, |
| 40 | + amsgrad: bool = False, |
| 41 | + weight_decouple: bool = True, |
| 42 | + fixed_decay: bool = False, |
| 43 | + rectify: bool = True, |
| 44 | + degenerated_to_sgd: bool = True, |
| 45 | + ): |
| 46 | + """AdaBelief optimizer |
| 47 | + :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups |
| 48 | + :param lr: float. learning rate |
| 49 | + :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace |
| 50 | + :param eps: float. term added to the denominator to improve numerical stability |
| 51 | + :param weight_decay: float. weight decay (L2 penalty) |
| 52 | + :param n_sma_threshold: (recommended is 5) |
| 53 | + :param amsgrad: bool. whether to use the AMSBound variant |
| 54 | + :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW |
| 55 | + :param fixed_decay: bool. |
| 56 | + :param rectify: bool. perform the rectified update similar to RAdam |
| 57 | + :param degenerated_to_sgd: bool. perform SGD update when variance of gradient is high |
| 58 | + """ |
| 59 | + self.lr = lr |
| 60 | + self.betas = betas |
| 61 | + self.eps = eps |
| 62 | + self.weight_decay = weight_decay |
| 63 | + self.n_sma_threshold = n_sma_threshold |
| 64 | + self.degenerated_to_sgd = degenerated_to_sgd |
| 65 | + self.weight_decouple = weight_decouple |
| 66 | + self.rectify = rectify |
| 67 | + self.fixed_decay = fixed_decay |
| 68 | + self.degenerated_to_sgd = degenerated_to_sgd |
| 69 | + |
| 70 | + if ( |
| 71 | + isinstance(params, (list, tuple)) |
| 72 | + and len(params) > 0 |
| 73 | + and isinstance(params[0], dict) |
| 74 | + ): |
| 75 | + for param in params: |
| 76 | + if 'betas' in param and ( |
| 77 | + param['betas'][0] != betas[0] |
| 78 | + or param['betas'][1] != betas[1] |
| 79 | + ): |
| 80 | + param['buffer'] = [[None, None, None] for _ in range(10)] |
| 81 | + |
| 82 | + defaults: DEFAULT_PARAMETERS = dict( |
| 83 | + lr=lr, |
| 84 | + betas=betas, |
| 85 | + eps=eps, |
| 86 | + weight_decay=weight_decay, |
| 87 | + amsgrad=amsgrad, |
| 88 | + buffer=[[None, None, None] for _ in range(10)], |
| 89 | + ) |
| 90 | + super().__init__(params, defaults) |
| 91 | + |
| 92 | + def __setstate__(self, state: STATE): |
| 93 | + super().__setstate__(state) |
| 94 | + for group in self.param_groups: |
| 95 | + group.setdefault('amsgrad', False) |
| 96 | + |
| 97 | + def reset(self): |
| 98 | + for group in self.param_groups: |
| 99 | + for p in group['params']: |
| 100 | + state = self.state[p] |
| 101 | + amsgrad = group['amsgrad'] |
| 102 | + |
| 103 | + state['step'] = 0 |
| 104 | + state['exp_avg'] = torch.zeros_like(p.data) |
| 105 | + state['exp_avg_var'] = torch.zeros_like(p.data) |
| 106 | + if amsgrad: |
| 107 | + state['max_exp_avg_var'] = torch.zeros_like(p.data) |
| 108 | + |
| 109 | + def step(self, closure: CLOSURE = None) -> LOSS: |
| 110 | + loss: LOSS = None |
| 111 | + if closure is not None: |
| 112 | + loss = closure() |
| 113 | + |
| 114 | + for group in self.param_groups: |
| 115 | + for p in group['params']: |
| 116 | + if p.grad is None: |
| 117 | + continue |
| 118 | + |
| 119 | + half_precision: bool = False |
| 120 | + if p.data.dtype == torch.float16: |
| 121 | + half_precision = True |
| 122 | + p.data = p.data.float() |
| 123 | + p.grad = p.grad.float() |
| 124 | + |
| 125 | + grad = p.grad.data |
| 126 | + if grad.is_sparse: |
| 127 | + raise RuntimeError( |
| 128 | + 'AdaBelief does not support sparse gradients' |
| 129 | + ) |
| 130 | + |
| 131 | + amsgrad = group['amsgrad'] |
| 132 | + |
| 133 | + state = self.state[p] |
| 134 | + |
| 135 | + beta1, beta2 = group['betas'] |
| 136 | + |
| 137 | + if len(state) == 0: |
| 138 | + state['step'] = 0 |
| 139 | + state['exp_avg'] = torch.zeros_like(p.data) |
| 140 | + state['exp_avg_var'] = torch.zeros_like(p.data) |
| 141 | + if amsgrad: |
| 142 | + state['max_exp_avg_var'] = torch.zeros_like(p.data) |
| 143 | + |
| 144 | + if self.weight_decouple: |
| 145 | + if not self.fixed_decay: |
| 146 | + p.data.mul_(1.0 - group['lr'] * group['weight_decay']) |
| 147 | + else: |
| 148 | + p.data.mul_(1.0 - group['weight_decay']) |
| 149 | + else: |
| 150 | + if group['weight_decay'] != 0: |
| 151 | + grad.add_(p.data, alpha=group['weight_decay']) |
| 152 | + |
| 153 | + exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var'] |
| 154 | + |
| 155 | + state['step'] += 1 |
| 156 | + bias_correction1 = 1 - beta1 ** state['step'] |
| 157 | + bias_correction2 = 1 - beta2 ** state['step'] |
| 158 | + |
| 159 | + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) |
| 160 | + grad_residual = grad - exp_avg |
| 161 | + exp_avg_var.mul_(beta2).addcmul_( |
| 162 | + grad_residual, grad_residual, value=1 - beta2 |
| 163 | + ) |
| 164 | + |
| 165 | + if amsgrad: |
| 166 | + max_exp_avg_var = state['max_exp_avg_var'] |
| 167 | + |
| 168 | + torch.max( |
| 169 | + max_exp_avg_var, |
| 170 | + exp_avg_var.add_(group['eps']), |
| 171 | + out=max_exp_avg_var, |
| 172 | + ) |
| 173 | + |
| 174 | + denom = ( |
| 175 | + max_exp_avg_var.sqrt() / math.sqrt(bias_correction2) |
| 176 | + ).add_(group['eps']) |
| 177 | + else: |
| 178 | + denom = ( |
| 179 | + exp_avg_var.add_(group['eps']).sqrt() |
| 180 | + / math.sqrt(bias_correction2) |
| 181 | + ).add_(group['eps']) |
| 182 | + |
| 183 | + if not self.rectify: |
| 184 | + step_size = group['lr'] / bias_correction1 |
| 185 | + p.data.addcdiv_(exp_avg, denom, value=-step_size) |
| 186 | + else: |
| 187 | + buffered = group['buffer'][int(state['step'] % 10)] |
| 188 | + if state['step'] == buffered[0]: |
| 189 | + n_sma, step_size = buffered[1], buffered[2] |
| 190 | + else: |
| 191 | + buffered[0] = state['step'] |
| 192 | + beta2_t = beta2 ** state['step'] |
| 193 | + n_sma_max = 2 / (1 - beta2) - 1 |
| 194 | + n_sma = n_sma_max - 2 * state['step'] * beta2_t / ( |
| 195 | + 1 - beta2_t |
| 196 | + ) |
| 197 | + buffered[1] = n_sma |
| 198 | + |
| 199 | + if n_sma >= self.n_sma_threshold: |
| 200 | + step_size = math.sqrt( |
| 201 | + (1 - beta2_t) |
| 202 | + * (n_sma - 4) |
| 203 | + / (n_sma_max - 4) |
| 204 | + * (n_sma - 2) |
| 205 | + / n_sma |
| 206 | + * n_sma_max |
| 207 | + / (n_sma_max - 2) |
| 208 | + ) / (1 - beta1 ** state['step']) |
| 209 | + elif self.degenerated_to_sgd: |
| 210 | + step_size = 1.0 / (1 - beta1 ** state['step']) |
| 211 | + else: |
| 212 | + step_size = -1 |
| 213 | + buffered[2] = step_size |
| 214 | + |
| 215 | + if n_sma >= self.n_sma_threshold: |
| 216 | + denom = exp_avg_var.sqrt().add_(group['eps']) |
| 217 | + p.data.addcdiv_( |
| 218 | + exp_avg, denom, value=-step_size * group['lr'] |
| 219 | + ) |
| 220 | + elif step_size > 0: |
| 221 | + p.data.add_(exp_avg, alpha=-step_size * group['lr']) |
| 222 | + |
| 223 | + if half_precision: |
| 224 | + p.data = p.data.half() |
| 225 | + p.grad = p.grad.half() |
| 226 | + |
| 227 | + return loss |
0 commit comments