Commit f89fb3e: "add layerwise lr" (1 parent: dc054df)
6 files changed: +315 −14 lines

composer/callbacks/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@
 from composer.callbacks.optimizer_monitor import OptimizerMonitor
 from composer.callbacks.speed_monitor import SpeedMonitor
 from composer.callbacks.threshold_stopper import ThresholdStopper
+from composer.callbacks.loss_spike_intervention import LossSpikeIntervention
 
 __all__ = [
     'OptimizerMonitor',
@@ -27,5 +28,6 @@
     'EarlyStopper',
     'ExportForInferenceCallback',
     'ThresholdStopper',
+    'LossSpikeIntervention',
     'ImageVisualizer',
 ]
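
With this re-export in place, the callback can be imported straight from the package namespace. A one-line usage note, not part of the diff itself:

from composer.callbacks import LossSpikeIntervention  # resolved via the new import in __init__.py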

composer/callbacks/loss_spike_intervention.py (new file)

Lines changed: 192 additions & 0 deletions
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Monitor gradients during training."""

import collections

import torch

from composer.core import Callback, State
from composer.loggers import Logger
from composer.utils import dist

__all__ = ['LossSpikeIntervention']


class MetricSpikeDetector:

    def __init__(self,
                 window_moving_average=25,
                 increase_factor=5,
                 increase_lookback=500,
                 plateau_min_duration=100,
                 end_spike_factor=1.10):
        self.window_moving_average = window_moving_average
        self.increase_factor = increase_factor
        self.plateau_min_duration = plateau_min_duration
        self.increase_lookback = increase_lookback
        self.fast_moving_average = collections.deque(maxlen=window_moving_average)
        self.intermediate_data_queue = collections.deque(maxlen=increase_lookback - window_moving_average)
        self.slow_moving_average = collections.deque(maxlen=increase_lookback)
        self.end_spike_factor = end_spike_factor
        self.in_spike = False
        self.mva_before_spike = None
        self.spike_batch_idx_start = None

    def insert_observation(self, obs, batch_idx):
        if len(self.fast_moving_average) >= self.fast_moving_average.maxlen:
            # move the oldest obs out of the fast moving average into the
            # intermediate data queue
            fast_obs = self.fast_moving_average.popleft()

            if len(self.intermediate_data_queue) >= self.intermediate_data_queue.maxlen:
                # move data from the intermediate queue into the slow moving-average queue
                intermediate_obs = self.intermediate_data_queue.popleft()
                self.slow_moving_average.append(intermediate_obs)

            self.intermediate_data_queue.append(fast_obs)

        self.fast_moving_average.append(obs)

        fast_mva = sum(self.fast_moving_average) / len(self.fast_moving_average)
        if not self.in_spike:
            if len(self.slow_moving_average) > self.window_moving_average:
                if self.mva_before_spike is None:
                    slow_mva = sum(self.slow_moving_average) / len(self.slow_moving_average)
                else:
                    slow_mva = self.mva_before_spike

                if fast_mva >= self.increase_factor * slow_mva:
                    self.in_spike = True
                    self.mva_before_spike = slow_mva
                    self.spike_batch_idx_start = batch_idx
        else:
            if batch_idx - self.spike_batch_idx_start > self.plateau_min_duration:
                # kill the layer!
                return True
            else:
                if fast_mva <= self.mva_before_spike * self.end_spike_factor:
                    self.in_spike = False
                    self.spike_batch_idx_start = None

        return False


class LossSpikeIntervention(Callback):

    def __init__(self,
                 metric='l2_norm/moment',
                 window_moving_average=25,
                 increase_factor=5,
                 increase_lookback=500,
                 plateau_min_duration=100,
                 end_spike_factor=1.10,
                 lr_scale=0.0):
        self.metric = metric
        self.lr_scale = lr_scale
        self.window_moving_average = window_moving_average
        self.increase_factor = increase_factor
        self.increase_lookback = increase_lookback
        self.plateau_min_duration = plateau_min_duration
        self.end_spike_factor = end_spike_factor

        self.metric_spike_detectors = {}
        self.frozen_layers = set()
        self.all_layers = set()

    def fit_start(self, state: State, logger: Logger) -> None:
        for name, p in state.model.named_parameters():
            if p.requires_grad:
                self.all_layers.add(name)
                full_metric_name = f"{self.metric}/{name}"
                self.metric_spike_detectors[full_metric_name] = MetricSpikeDetector(
                    self.window_moving_average,
                    self.increase_factor,
                    self.increase_lookback,
                    self.plateau_min_duration,
                    self.end_spike_factor,
                )

    def batch_end(self, state: State, logger: Logger):
        norm = 0.0
        optimizer_metrics = {}

        for name, p in state.model.named_parameters():
            if p.grad is not None and p.requires_grad:

                metric_reporter = getattr(state.optimizers[0], 'report_per_parameter_metrics', None)
                if callable(metric_reporter):
                    optimizer_metrics = metric_reporter(p, name, optimizer_metrics)

                if f'l2_norm/grad/{name}' not in optimizer_metrics:
                    param_grad_norm = torch.linalg.vector_norm(p.grad)
                    optimizer_metrics[f'l2_norm/grad/{name}'] = param_grad_norm

        if state.fsdp_enabled and dist.get_world_size() > 0:
            pre_reduce_metrics = getattr(state.optimizers[0], 'pre_reduce_metrics', None)
            if callable(pre_reduce_metrics):
                optimizer_metrics = pre_reduce_metrics(optimizer_metrics)

            dist_reduce_metrics = getattr(state.optimizers[0], 'dist_reduce_metrics', None)
            if callable(dist_reduce_metrics):
                optimizer_metrics = dist_reduce_metrics(optimizer_metrics)

        for metric in optimizer_metrics:
            if metric.startswith('l2_norm/grad'):
                norm += optimizer_metrics[metric]**2

        optimizer_metrics['l2_norm/grad/global'] = norm**0.5

        for metric in optimizer_metrics:
            if isinstance(optimizer_metrics[metric], torch.Tensor):
                optimizer_metrics[metric] = optimizer_metrics[metric].item()

        batch_idx = state.timestamp.batch.value
        newly_failed_layers = self.detect_failed_layers(optimizer_metrics, batch_idx)

        if len(newly_failed_layers) > 0:
            self.freeze_layers(newly_failed_layers, state)
            for optimizer in state.optimizers:
                for group in optimizer.param_groups:
                    group['lr'] *= self.lr_scale

            for scheduler in state.schedulers:
                scheduler.base_lrs = [self.lr_scale * lr for lr in scheduler.base_lrs]

        optimizer_metrics['num_frozen_layers'] = len(self.frozen_layers)
        logger.log_metrics(optimizer_metrics)

        if len(self.all_layers) == 0:
            state.stop_training()

    def freeze_layers(self, newly_failed_layers, state):
        for layer in newly_failed_layers:
            self.all_layers.remove(layer)
            if layer not in self.frozen_layers:
                self.frozen_layers.add(layer)

        for name, p in state.model.named_parameters():
            if name in self.frozen_layers:
                p.requires_grad = False

    def detect_failed_layers(self, optimizer_metrics, batch_idx):
        newly_failed = []
        for logger_name, value in optimizer_metrics.items():
            if logger_name.startswith(self.metric):
                layer_name = logger_name.split('/')[-1]
                if layer_name in self.frozen_layers:
                    continue
                if self.metric_spike_detectors[logger_name].insert_observation(value, batch_idx):
                    newly_failed.append(layer_name)

        return newly_failed
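
How MetricSpikeDetector decides that a metric has spiked and then plateaued is easiest to see on synthetic data. The sketch below is illustrative only and is not part of the commit; it uses deliberately small window/lookback values so the slow moving average warms up quickly, then feeds a flat metric that jumps 10x and never recovers:

from composer.callbacks.loss_spike_intervention import MetricSpikeDetector

# Small, illustrative settings (the defaults are 25 / 5 / 500 / 100 / 1.10).
detector = MetricSpikeDetector(window_moving_average=5,
                               increase_factor=5,
                               increase_lookback=20,
                               plateau_min_duration=10,
                               end_spike_factor=1.10)

flagged_at = None
for batch_idx in range(100):
    obs = 1.0 if batch_idx < 50 else 10.0  # steady metric, then a 10x jump that never recovers
    if detector.insert_observation(obs, batch_idx):
        flagged_at = batch_idx  # True means "kill the layer"
        break

print(flagged_at)  # fires roughly plateau_min_duration batches after the spike is first detected

Had the metric fallen back below mva_before_spike * end_spike_factor within plateau_min_duration batches, the detector would have exited the spike and kept returning False.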

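For context, a hedged sketch of how the new callback might be wired into a run. This wiring is not part of the commit, and model / train_dataloader are placeholders for whatever ComposerModel and dataloader the run already uses; DecoupledAdamW is a natural pairing because batch_end pulls report_per_parameter_metrics off state.optimizers[0]:

from composer import Trainer
from composer.callbacks import LossSpikeIntervention
from composer.optim import DecoupledAdamW

optimizer = DecoupledAdamW(model.parameters(), lr=1e-4)  # model: placeholder ComposerModel

trainer = Trainer(
    model=model,                        # placeholder
    train_dataloader=train_dataloader,  # placeholder
    optimizers=optimizer,
    max_duration='1ep',
    callbacks=[
        LossSpikeIntervention(metric='l2_norm/moment',
                              window_moving_average=25,
                              increase_factor=5,
                              plateau_min_duration=100,
                              lr_scale=0.0),
    ],
)
trainer.fit()

With lr_scale=0.0, the first confirmed spike freezes the offending layer and zeroes the learning rate of every remaining parameter group and scheduler; training stops once every trainable layer has been frozen.
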
composer/optim/decoupled_weight_decay.py

Lines changed: 47 additions & 12 deletions
@@ -232,12 +232,22 @@ def __init__(self,
         super().__init__(params=params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
         for group in self.param_groups:
             group['initial_lr'] = group['lr']
+        self.layer_to_scale = None
+
+    def get_scaling(self, param):
+        if self.layer_to_scale:
+            if param not in self.layer_to_scale:
+                raise Exception(f"Couldn't find param: {param} in layer to scale: {self.layer_to_scale}")
+            else:
+                return self.layer_to_scale[param]
+        else:
+            return 1.0
 
     @staticmethod
     def adamw(params: List[torch.Tensor], grads: List[torch.Tensor], exp_avgs: List[torch.Tensor],
               exp_avg_sqs: List[torch.Tensor], max_exp_avg_sqs: List[torch.Tensor], state_steps: List[int], *,
-              amsgrad: bool, beta1: float, beta2: float, lr: float, initial_lr: float, weight_decay: float,
-              eps: float) -> None:
+              amsgrad: bool, beta1: float, beta2: float, lr: float, initial_lr: float, weight_decay: float, eps: float,
+              layerwise_lrs) -> None:
         r"""Functional API that performs AdamW algorithm computation with decoupled weight decay.
 
         Args:
@@ -280,10 +290,26 @@ def adamw(params: List[torch.Tensor], grads: List[torch.Tensor], exp_avgs: List[
             else:
                 denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
 
-            step_size = lr / bias_correction1
+            step_size = lr * layerwise_lrs[i] / bias_correction1
 
             param.addcdiv_(exp_avg, denom, value=-step_size)
 
+    def reset_state(self):
+        for group in self.param_groups:
+            amsgrad = group['amsgrad']
+            for p in group['params']:
+                if not p.requires_grad:
+                    continue
+                state = self.state[p]
+                state['step'] = 0
+                # Exponential moving average of gradient values
+                state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                # Exponential moving average of squared gradient values
+                state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                if amsgrad:
+                    # Maintains max of all exp. moving avg. of sq. grad. values
+                    state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
     @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -304,6 +330,7 @@ def step(self, closure=None):
             exp_avg_sqs = []
             max_exp_avg_sqs = []
             state_steps = []
+            layerwise_lrs = []
             amsgrad = group['amsgrad']
             beta1, beta2 = group['betas']
             eps = group['eps']
@@ -312,7 +339,7 @@ def step(self, closure=None):
             weight_decay = group['weight_decay']
 
             for p in group['params']:
-                if p.grad is None:
+                if p.grad is None or not p.requires_grad:
                     continue
                 params_with_grad.append(p)
                 if p.grad.is_sparse:
@@ -322,7 +349,7 @@ def step(self, closure=None):
                 state = self.state[p]
 
                 # State initialization
-                if len(state) == 0:
+                if 'step' not in state:
                     state['step'] = 0
                     # Exponential moving average of gradient values
                     state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
@@ -334,7 +361,7 @@ def step(self, closure=None):
 
                 exp_avgs.append(state['exp_avg'])
                 exp_avg_sqs.append(state['exp_avg_sq'])
-
+                layerwise_lrs.append(self.get_scaling(p))
                 if amsgrad:
                     max_exp_avg_sqs.append(state['max_exp_avg_sq'])
 
@@ -355,13 +382,16 @@
                        lr=lr,
                        initial_lr=initial_lr,
                        weight_decay=weight_decay,
-                       eps=eps)
+                       eps=eps,
+                       layerwise_lrs=layerwise_lrs)
 
         return loss
 
     def dist_reduce_metrics(self, optimizer_metrics):
         for metric in optimizer_metrics:
-            if metric.startswith('l2_norm'):
+            if metric.startswith('layerwise_lr_scaling'):
+                continue
+            elif metric.startswith('l2_norm'):
                 reduced = optimizer_metrics[metric]
                 if dist.get_world_size() > 1:
                     dist.all_reduce(reduced, reduce_operation='SUM')
@@ -385,15 +415,17 @@ def dist_reduce_metrics(self, optimizer_metrics):
                 if dist.get_world_size() > 1:
                     dist.all_reduce(reduced, reduce_operation='SUM')
                 optimizer_metrics[metric] = reduced / dist.get_world_size()
 
         return optimizer_metrics
 
     def pre_reduce_metrics(self, optimizer_metrics):
         # some of the metrics need to be modified before being reduced in order for the
         # reduction to work properly
 
         for metric in optimizer_metrics:
-            if metric.startswith('l2_norm'):
+            if metric.startswith('layerwise_lr_scaling'):
+                continue
+            elif metric.startswith('l2_norm'):
                 # l2 norms need to be squared, before they are reduced via summation
                 optimizer_metrics[metric] = optimizer_metrics[metric]**2
             elif metric.startswith('cosine'):
@@ -418,16 +450,19 @@ def report_per_parameter_metrics(self, param: torch.Tensor, name: str, optimizer
         beta1, beta2 = self.param_groups[0]['betas']
         if param in self.state:
             param_optim_state = self.state[param]
+            local_lr = lr * self.get_scaling(param)
             step = param_optim_state['step']
             bias_correction1 = 1 - beta1**step
             bias_correction2 = 1 - beta2**step
             denom = (param_optim_state['exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)).add_(eps)
-            step_size = lr / bias_correction1
+            step_size = local_lr / bias_correction1
             step_tensor = step_size * param_optim_state['exp_avg'].div(denom)
-            decay_factor = (lr / initial_lr) if initial_lr else 1.0
+            decay_factor = (local_lr / initial_lr) if initial_lr else 1.0
             step_tensor.add_(param, alpha=-weight_decay * decay_factor)
             for metric in self.metric_functions:
                 optimizer_metrics[f'{metric}/{name}'] = self.metric_functions[metric](param, param_optim_state,
                                                                                       step_tensor)
 
+        optimizer_metrics[f'layerwise_lr_scaling/{name}'] = self.get_scaling(param)
+
         return optimizer_metrics
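
Nothing in this commit populates layer_to_scale: it stays None, get_scaling returns 1.0 for every parameter, and step() reduces to the previous behavior. The sketch below is an assumption about how a caller could use the new hook, not something the commit itself does; note that get_scaling looks the mapping up by the parameter tensor itself, so the keys must be the parameters rather than their names, and every parameter the optimizer steps must be present:

import torch
from composer.optim import DecoupledAdamW

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
optimizer = DecoupledAdamW(model.parameters(), lr=1e-3)

# Hypothetical wiring: halve the effective LR of the second Linear, leave the rest at the base LR.
optimizer.layer_to_scale = {
    p: (0.5 if name.startswith('1.') else 1.0) for name, p in model.named_parameters()
}

loss = model(torch.randn(4, 8)).sum()
loss.backward()
optimizer.step()  # adamw() receives layerwise_lrs == [1.0, 1.0, 0.5, 0.5]

The new reset_state() re-zeros step, exp_avg, and exp_avg_sq for every still-trainable parameter, so a caller could pair it with a fresh layer_to_scale to restart the Adam statistics after an intervention.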
