Commit ff3ad20
Speed monitor refactor (#1987)
* add speed monitor refactor
* fix docs
* fix tests
* fix remove 1
* extend test
* format
* respond to comments
* restore caching
* add deepcopy
* add comment
1 parent 0926cbf commit ff3ad20

4 files changed: +301 -66 lines changed
composer/callbacks/speed_monitor.py

Lines changed: 218 additions & 52 deletions
@@ -4,24 +4,144 @@
 """Monitor throughput during training."""
 
 from __future__ import annotations
 
+import warnings
 from collections import deque
-from typing import Any, Deque, Dict
+from typing import Any, Callable, Deque, Dict, Optional, Union
+
+import torch
 
 from composer.core import Callback, State
 from composer.loggers import Logger
+from composer.models.base import ComposerModel
+from composer.utils import dist
 
 __all__ = ['SpeedMonitor']
 
+GPU_AVAILABLE_FLOPS = {
+    # source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
+    # nvidia publishes spec sheet with a 2x sparsity factor
+    'h100-sxm': {
+        'fp64': 67e12,
+        'fp32': 67e12,
+        'tf32': 989e12 / 2,
+        'fp16': 1.979e15 / 2,
+        'amp_fp16': 1.979e15 / 2,
+        'bf16': 1.979e15 / 2,
+        'amp_bf16': 1.979e15 / 2,
+        'fp8': 3.958e15 / 2,
+        'amp_fp8': 3.958e15 / 2,
+        'int8': 3.958e15 / 2,
+    },
+    'h100-pcie': {
+        'fp64': 51e12,
+        'fp32': 51e12,
+        'tf32': 756e12 / 2,
+        'fp16': 1.513e15 / 2,
+        'amp_fp16': 1.513e15 / 2,
+        'bf16': 1.513e15 / 2,
+        'amp_bf16': 1.513e15 / 2,
+        'fp8': 3.026e15 / 2,
+        'amp_fp8': 3.026e15 / 2,
+        'int8': 3.026e15 / 2,
+    },
+    # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
+    # sxm and pcie have same flop counts
+    'a100': {
+        'fp64': 19.5e12,
+        'fp32': 19.5e12,
+        'tf32': 156e12,
+        'fp16': 312e12,
+        'amp_fp16': 312e12,
+        'bf16': 312e12,
+        'amp_bf16': 312e12,
+    },
+    # source: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf
+    'v100-sxm': {
+        'fp64': 7.8e12,
+        'fp32': 15.7e12,
+        'fp16': 125e12,
+        'amp_fp16': 125e12,
+    },
+    'v100-pcie': {
+        'fp64': 7e12,
+        'fp32': 14e12,
+        'fp16': 112e12,
+        'amp_fp16': 112e12,
+    },
+    'v100s-pcie': {
+        'fp64': 8.2e12,
+        'fp32': 16.4e12,
+        'fp16': 130e12,
+        'amp_fp16': 130e12,
+    },
+    # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf
+    # sxm and pcie have same flop counts
+    't4': {
+        'fp32': 8.1e12,
+        'fp16': 65e12,
+        'amp_fp16': 65e12,
+        'int8': 130e12,
+        'int4': 260e12,
+    },
+}
+
+
+def get_gpu_flops_available(state: State):
+    gpu_flops_available = None
+
+    # Return 0 if no CUDA device (e.g., when running with CPU only)
+    if not torch.cuda.is_available():
+        return 0
+
+    # torch.cuda.get_device_name() ex output: 'NVIDIA A100-SXM4-40GB'
+    device_name = torch.cuda.get_device_name().lower()
+    if 'h100-sxm' in device_name:
+        device_name = 'h100-sxm'
+    elif 'h100-pcie' in device_name:
+        device_name = 'h100-pcie'
+    elif 'a100' in device_name:
+        device_name = 'a100'
+    elif 'v100-sxm' in device_name:
+        device_name = 'v100-sxm'
+    elif 'v100-pcie' in device_name:
+        device_name = 'v100-pcie'
+    elif 't4' in device_name:
+        device_name = 't4'
+    else:
+        device_name = None
+
+    if device_name is not None:
+        try:
+            gpu_flops_available = int(GPU_AVAILABLE_FLOPS[device_name][state.precision.value])
+        except:
+            gpu_flops_available = None
+
+    if gpu_flops_available is None:
+        warnings.warn(
+            f'gpu_flop count not found for {device_name} with precision: {state.precision.value}; ' +\
+            f'MFU cannot be calculated and reported. gpu_flops_available can be manually ' +\
+            f'overridden by setting gpu_flops_available in SpeedMonitor.'
+        )
+        # Setting to 0 will disable MFU computation and prevent
+        # the speed monitor from running this helper every batch
+        gpu_flops_available = 0
+
+    return gpu_flops_available
+
 
 class SpeedMonitor(Callback):
     """Logs the training throughput.
 
-    The training throughput in terms of number of samples per second is logged on the
-    :attr:`.Event.BATCH_END` event if we have reached the ``window_size`` threshold.
+    The training throughput is logged on the :attr:`.Event.BATCH_END` event once we have reached
+    the `window_size` threshold. If a model has a `flops_per_batch` attribute, then flops per second
+    is also logged. If running on a known GPU type or if `gpu_flops_available` is set, then MFU is
+    also logged. All metrics are also logged per device by dividing by world size.
 
-    The wall clock train time is logged on every :attr:`.Event.BATCH_END` event.
+    To compute `flops_per_sec`, the model attribute `flops_per_batch` should be set to a callable
+    which accepts a batch and returns the number of flops for that batch. Typically, this should
+    be flops per sample times the batch size unless pad tokens are used.
 
-    The average throughout over an epoch is logged on the :attr:`.Event.EPOCH_END` event.
+    The wall clock time is logged on every :attr:`.Event.BATCH_END` event.
 
     Example:
         .. doctest::
@@ -41,84 +161,130 @@ class SpeedMonitor(Callback):
     The training throughput is logged by the :class:`.Logger` to the following keys as
     described below.
 
-    +----------------------------------+-------------------------------------------------------------+
-    | Key                              | Logged data                                                 |
-    +==================================+=============================================================+
-    |                                  | Rolling average (over ``window_size`` most recent           |
-    | ``throughput/samples_per_sec``   | batches) of the number of samples processed per second      |
-    |                                  |                                                             |
-    +----------------------------------+-------------------------------------------------------------+
-    | ``wall_clock/train``             | Total elapsed training time                                 |
-    +----------------------------------+-------------------------------------------------------------+
-    | ``wall_clock/val``               | Total elapsed validation time                               |
-    +----------------------------------+-------------------------------------------------------------+
-    | ``wall_clock/total``             | Total elapsed time (wall_clock/train + wall_clock/val)      |
-    +----------------------------------+-------------------------------------------------------------+
+    +-------------------------------------+-----------------------------------------------------------+
+    | Key                                 | Logged data                                               |
+    +=====================================+===========================================================+
+    |                                     | Rolling average (over `window_size` most recent           |
+    | `throughput/batches_per_sec`        | batches) of the number of batches processed per second    |
+    |                                     |                                                           |
+    +-------------------------------------+-----------------------------------------------------------+
+    |                                     | Rolling average (over `window_size` most recent           |
+    | `throughput/samples_per_sec`        | batches) of the number of samples processed per second    |
+    |                                     |                                                           |
+    +-------------------------------------+-----------------------------------------------------------+
+    |                                     | Rolling average (over `window_size` most recent           |
+    | `throughput/tokens_per_sec`         | batches) of the number of tokens processed per second.    |
+    |                                     | Only logged when dataloader.dataset has `max_seq_len`.    |
+    |                                     | This may include padding depending on dataset             |
+    +-------------------------------------+-----------------------------------------------------------+
+    |                                     | Estimates flops by `flops_per_batch * batches_per_sec`    |
+    | `throughput/flops_per_sec`          | if model has attribute `flops_per_batch`                  |
+    |                                     |                                                           |
+    +-------------------------------------+-----------------------------------------------------------+
+    | `throughput/device/batches_per_sec` | `throughput/batches_per_sec` divided by world size        |
+    +-------------------------------------+-----------------------------------------------------------+
+    | `throughput/device/samples_per_sec` | `throughput/samples_per_sec` divided by world size        |
+    +-------------------------------------+-----------------------------------------------------------+
+    |                                     | `throughput/tokens_per_sec` divided by world size. Only   |
+    | `throughput/device/tokens_per_sec`  | logged when dataloader.dataset has `max_seq_len`. This    |
+    |                                     | may include pad tokens depending on dataset               |
+    +-------------------------------------+-----------------------------------------------------------+
+    |                                     | `throughput/flops_per_sec` divided by world size. Only    |
+    | `throughput/device/flops_per_sec`   | logged when model has attribute `flops_per_batch`         |
+    |                                     |                                                           |
+    +-------------------------------------+-----------------------------------------------------------+
+    |                                     | `throughput/device/flops_per_sec` divided by              |
+    | `throughput/device/mfu`             | `gpu_flops_available`. Only logged when model has         |
+    |                                     | attribute `flops_per_batch` and `gpu_flops_available`,    |
+    |                                     | which can be passed as an argument if not automatically   |
+    |                                     | determined by SpeedMonitor                                |
+    +-------------------------------------+-----------------------------------------------------------+
+    | `wall_clock/train`                  | Total elapsed training time                               |
+    +-------------------------------------+-----------------------------------------------------------+
+    | `wall_clock/val`                    | Total elapsed validation time                             |
+    +-------------------------------------+-----------------------------------------------------------+
+    | `wall_clock/total`                  | Total elapsed time (wall_clock/train + wall_clock/val)    |
+    +-------------------------------------+-----------------------------------------------------------+
 
     Args:
         window_size (int, optional): Number of batches to use for a rolling average of throughput.
             Defaults to 100.
     """
 
-    def __init__(self, window_size: int = 100):
+    def __init__(self, window_size: int = 100, gpu_flops_available: Optional[Union[float, int]] = None):
         # Track the batch num samples and wct to compute throughput over a window of batches
-        self.batch_start_num_samples = 0
-        self.batch_start_wct = 0.0
-        self.batch_wct_buffer: Deque[float] = deque(maxlen=window_size)
-        self.batch_num_samples_buffer: Deque[int] = deque(maxlen=window_size)
-        self.window_size = window_size
+        self.history_samples: Deque[int] = deque(maxlen=window_size + 1)
+        self.history_wct: Deque[float] = deque(maxlen=window_size + 1)
+
+        self.gpu_flops_available = gpu_flops_available
 
         # Keep track of time spent evaluating
         self.total_eval_wct = 0.0
 
     def state_dict(self) -> Dict[str, Any]:
         return {
-            'batch_start_num_samples': self.batch_start_num_samples,
-            'batch_start_wct': self.batch_start_wct,
-            'batch_wct_buffer': self.batch_wct_buffer,
-            'batch_num_samples_buffer': self.batch_num_samples_buffer,
-            # "window_wct": self.window_wct,
-            # "window_num_samples": self.window_num_samples,
             'total_eval_wct': self.total_eval_wct,
         }
 
     def load_state_dict(self, state: Dict[str, Any]) -> None:
-        self.batch_start_num_samples = state['batch_start_num_samples']
-        self.batch_start_wct = state['batch_start_wct']
-        self.batch_wct_buffer = deque(
-            [x for x in state['batch_wct_buffer']],
-            maxlen=self.window_size,
-        )
-        self.batch_num_samples_buffer = deque(
-            [x for x in state['batch_num_samples_buffer']],
-            maxlen=self.window_size,
-        )
         self.total_eval_wct = state['total_eval_wct']
 
-    def before_dataloader(self, state: State, logger: Logger) -> None:
+    def init(self, state: State, logger: Logger) -> None:
         del logger  # unused
-        self.batch_start_wct = state.timestamp.total_wct.total_seconds()
-        self.batch_start_num_samples = int(state.timestamp.sample)
+        if self.gpu_flops_available is None:
+            self.gpu_flops_available = get_gpu_flops_available(state)
 
     def batch_end(self, state: State, logger: Logger):
-        batch_num_samples = int(state.timestamp.sample) - self.batch_start_num_samples
-        batch_wct = state.timestamp.total_wct.total_seconds() - self.batch_start_wct
-
         # Add the new element
-        self.batch_wct_buffer.append(batch_wct)
-        self.batch_num_samples_buffer.append(batch_num_samples)
+        self.history_samples.append(state.timestamp.sample.value)
+        self.history_wct.append(state.timestamp.total_wct.total_seconds())
 
         # Log the throughput
-        if len(self.batch_num_samples_buffer) == self.window_size:
-            throughput = sum(self.batch_num_samples_buffer) / sum(self.batch_wct_buffer)
-            logger.log_metrics({'throughput/samples_per_sec': throughput})
+        if len(self.history_wct) == self.history_wct.maxlen:
+            world_size = dist.get_world_size()
+            elapsed_batches = len(self.history_samples) - 1
+            elapsed_samples = int(self.history_samples[-1]) - int(self.history_samples[0])
+            elapsed_wct = self.history_wct[-1] - self.history_wct[0]
+            batches_per_sec = elapsed_batches / elapsed_wct
+            samples_per_sec = elapsed_samples / elapsed_wct
+            dev_batches_per_sec = batches_per_sec / world_size
+            dev_samples_per_sec = samples_per_sec / world_size
+            logger.log_metrics({'throughput/batches_per_sec': batches_per_sec})
+            logger.log_metrics({'throughput/samples_per_sec': samples_per_sec})
+            logger.log_metrics({'throughput/device/batches_per_sec': dev_batches_per_sec})
+            logger.log_metrics({'throughput/device/samples_per_sec': dev_samples_per_sec})
+
+            # Compute token stats if dataloader.dataset has max_seq_len. Assumes no padding.
+            try:
+                max_seq_len = state.dataloader.dataset.max_seq_len  # type: ignore
+                # Only applicable to seq data / models
+                logger.log_metrics({'throughput/tokens_per_sec': samples_per_sec * max_seq_len})
+                logger.log_metrics({'throughput/device/tokens_per_sec': dev_samples_per_sec * max_seq_len})
+            except AttributeError:
+                pass
+
+            composer_model = state.model
+            if not isinstance(composer_model, ComposerModel):
+                composer_model = composer_model.module  # Pass through DDP wrapping
+            if hasattr(composer_model, 'flops_per_batch'):
+                model_flops_per_batch = composer_model.flops_per_batch  # type: ignore
+                if not isinstance(model_flops_per_batch, Callable):
+                    raise TypeError('flops_per_batch must be a callable accepting a batch and '
+                                    f'returning an int or float. Instead, got {type(model_flops_per_batch)}.')
+                flops_per_batch = model_flops_per_batch(state.batch)
+                flops_per_sec = flops_per_batch * batches_per_sec
+                logger.log_metrics({'throughput/flops_per_sec': flops_per_sec})
+                dev_flops_per_sec = flops_per_sec / world_size
+                logger.log_metrics({'throughput/device/flops_per_sec': dev_flops_per_sec})
+                if self.gpu_flops_available:
+                    mfu = dev_flops_per_sec / self.gpu_flops_available
+                    logger.log_metrics({'throughput/device/mfu': mfu})
 
         # Log the time
         # `state.timestamp` excludes any time spent in evaluation
         logger.log_metrics({
             'wall_clock/train': state.timestamp.total_wct.total_seconds(),
             'wall_clock/val': self.total_eval_wct,
-            'wall_clock/total': (state.timestamp.total_wct.total_seconds() + self.total_eval_wct),
+            'wall_clock/total': state.timestamp.total_wct.total_seconds() + self.total_eval_wct,
         })
 
     def eval_end(self, state: State, logger: Logger):
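
For orientation, here is a rough usage sketch (not part of this commit) of how a training script might opt into the new flops and MFU logging. The model class, dataloader, and the 6 * n_params * tokens flop estimate are placeholder assumptions; only the `flops_per_batch` attribute and the `SpeedMonitor(window_size=..., gpu_flops_available=...)` arguments come from this diff.

# Hypothetical usage sketch -- MyComposerModel, train_dataloader, and the flop
# estimate below are placeholders, not part of this commit.
from composer import Trainer
from composer.callbacks import SpeedMonitor

model = MyComposerModel()  # any ComposerModel subclass (hypothetical)

# `flops_per_batch` must be a callable that takes a batch and returns its flop count.
# A common rough estimate for transformers is ~6 * n_params * tokens in the batch.
n_params = sum(p.numel() for p in model.parameters())
model.flops_per_batch = lambda batch: 6 * n_params * batch['input_ids'].numel()

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,  # assumed to be defined elsewhere
    max_duration='1ep',
    callbacks=[
        # gpu_flops_available is optional; when omitted, the callback calls
        # get_gpu_flops_available(state) on Event.INIT to look up the device peak.
        SpeedMonitor(window_size=100, gpu_flops_available=312e12),  # e.g. A100 bf16 peak
    ],
)
trainer.fit()

Setting `gpu_flops_available` explicitly is optional; leaving it as None lets `init()` call `get_gpu_flops_available(state)` once and cache the result, and an unrecognized device simply disables MFU logging.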

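To make the MFU bookkeeping concrete, here is a self-contained arithmetic check of how the logged keys relate; the numbers are invented for illustration and only the formulas mirror `batch_end` above.

# Invented numbers; only the formulas mirror the callback's batch_end logic.
world_size = 8                  # number of ranks
flops_per_sec = 2.4e15          # hypothetical aggregate `throughput/flops_per_sec`
gpu_flops_available = 312e12    # A100 bf16 peak from GPU_AVAILABLE_FLOPS

dev_flops_per_sec = flops_per_sec / world_size    # `throughput/device/flops_per_sec` -> 3.0e14
mfu = dev_flops_per_sec / gpu_flops_available     # `throughput/device/mfu`
print(f'{mfu:.2%}')                               # ~96.15% of per-device peak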
composer/trainer/trainer.py

Lines changed: 8 additions & 2 deletions
@@ -2104,7 +2104,8 @@ def _train_batch(self, use_grad_scaling: bool) -> Dict[str, torch.Tensor]:
         """
         assert self._train_data_spec is not None, 'The train data spec should be set on __init__ or fit()'
 
-        # Cache the device batch, because `self.state.batch` gets overridden in microbatching loop
+        # Cache the device batch, because `self.state.batch` gets overridden in microbatching loop.
+        # Any in-place changes to a microbatch will be reflected in the device batch.
         device_batch = self.state.batch
 
         # Retry until we successfully complete training and return loss
@@ -2212,8 +2213,10 @@ def _train_microbatches(self,
             except TypeError:
                 optimizer.zero_grad()
 
-        # tracker for gradient accumulation
+        # Tracker for gradient accumulation
         current_batch_size = sum([self._train_data_spec.get_num_samples_in_batch(batch) for batch in microbatches])
+        # Cache batch, which will be overwritten by microbatches. Restore after microbatches complete
+        current_batch = self.state.batch
 
         for microbatch_idx, self.state.batch in enumerate(microbatches):
             is_final_microbatch = microbatch_idx + 1 == len(microbatches)
@@ -2226,6 +2229,9 @@ def _train_microbatches(self,
                     total_loss_dict[loss_key] = self.state.device.tensor_to_device(torch.zeros(size=(1,)))
                 total_loss_dict[loss_key] += microbatch_loss
 
+        # Restore batch
+        self.state.batch = current_batch
+
         # Unscale gradients before `Event.AFTER_TRAIN_BATCH`
         if use_grad_scaling:
             for optimizer in ensure_tuple(self.state.optimizers):
