"""Monitor throughput during training."""
from __future__ import annotations

+ import warnings
from collections import deque
- from typing import Any, Deque, Dict
+ from typing import Any, Callable, Deque, Dict, Optional, Union
+
+ import torch

from composer.core import Callback, State
from composer.loggers import Logger
+ from composer.models.base import ComposerModel
+ from composer.utils import dist

__all__ = ['SpeedMonitor']

+ GPU_AVAILABLE_FLOPS = {
+     # source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
+     # nvidia publishes spec sheet with a 2x sparsity factor
+     'h100-sxm': {
+         'fp64': 67e12,
+         'fp32': 67e12,
+         'tf32': 989e12 / 2,
+         'fp16': 1.979e15 / 2,
+         'amp_fp16': 1.979e15 / 2,
+         'bf16': 1.979e15 / 2,
+         'amp_bf16': 1.979e15 / 2,
+         'fp8': 3.958e15 / 2,
+         'amp_fp8': 3.958e15 / 2,
+         'int8': 3.958e15 / 2,
+     },
+     'h100-pcie': {
+         'fp64': 51e12,
+         'fp32': 51e12,
+         'tf32': 756e12 / 2,
+         'fp16': 1.513e15 / 2,
+         'amp_fp16': 1.513e15 / 2,
+         'bf16': 1.513e15 / 2,
+         'amp_bf16': 1.513e15 / 2,
+         'fp8': 3.026e15 / 2,
+         'amp_fp8': 3.026e15 / 2,
+         'int8': 3.026e15 / 2,
+     },
+     # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
+     # sxm and pcie have same flop counts
+     'a100': {
+         'fp64': 19.5e12,
+         'fp32': 19.5e12,
+         'tf32': 156e12,
+         'fp16': 312e12,
+         'amp_fp16': 312e12,
+         'bf16': 312e12,
+         'amp_bf16': 312e12,
+     },
+     # source: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf
+     'v100-sxm': {
+         'fp64': 7.8e12,
+         'fp32': 15.7e12,
+         'fp16': 125e12,
+         'amp_fp16': 125e12,
+     },
+     'v100-pcie': {
+         'fp64': 7e12,
+         'fp32': 14e12,
+         'fp16': 112e12,
+         'amp_fp16': 112e12,
+     },
+     'v100s-pcie': {
+         'fp64': 8.2e12,
+         'fp32': 16.4e12,
+         'fp16': 130e12,
+         'amp_fp16': 130e12,
+     },
+     # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf
+     # sxm and pcie have same flop counts
+     't4': {
+         'fp32': 8.1e12,
+         'fp16': 65e12,
+         'amp_fp16': 65e12,
+         'int8': 130e12,
+         'int4': 260e12,
+     },
+ }
+
+
+ def get_gpu_flops_available(state: State):
+     gpu_flops_available = None
+
+     # Return 0 if no CUDA device (e.g., when running with CPU only)
+     if not torch.cuda.is_available():
+         return 0
+
+     # torch.cuda.get_device_name() ex output: 'NVIDIA A100-SXM4-40GB'
+     device_name = torch.cuda.get_device_name().lower()
+     if 'h100-sxm' in device_name:
+         device_name = 'h100-sxm'
+     elif 'h100-pcie' in device_name:
+         device_name = 'h100-pcie'
+     elif 'a100' in device_name:
+         device_name = 'a100'
+     elif 'v100-sxm' in device_name:
+         device_name = 'v100-sxm'
+     elif 'v100-pcie' in device_name:
+         device_name = 'v100-pcie'
+     elif 't4' in device_name:
+         device_name = 't4'
+     else:
+         device_name = None
+
+     if device_name is not None:
+         try:
+             gpu_flops_available = int(GPU_AVAILABLE_FLOPS[device_name][state.precision.value])
+         except KeyError:
+             gpu_flops_available = None
+
+     if gpu_flops_available is None:
+         warnings.warn(
+             f'gpu_flop count not found for {device_name} with precision: {state.precision.value}; '
+             'MFU cannot be calculated and reported. gpu_flops_available can be manually '
+             'overridden by setting gpu_flops_available in SpeedMonitor.')
+         # Setting to 0 will disable MFU computation and prevent
+         # the speed monitor from running this helper every batch
+         gpu_flops_available = 0
+
+     return gpu_flops_available
+

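For intuition, the MFU that the callback logs later is simply achieved flops per device divided by the peak value looked up here. A quick hand check against the table above; the throughput, flop count, and world size figures are made up for illustration:

```python
# Illustrative only: reproduce the MFU arithmetic by hand using the lookup table above.
peak_flops = GPU_AVAILABLE_FLOPS['a100']['amp_bf16']  # 312e12 flops/sec per device

flops_per_batch = 1.2e15   # hypothetical value returned by model.flops_per_batch(batch)
batches_per_sec = 0.2      # hypothetical measured throughput (all devices combined)
world_size = 8             # hypothetical number of GPUs

dev_flops_per_sec = flops_per_batch * batches_per_sec / world_size
mfu = dev_flops_per_sec / peak_flops
print(f'{mfu:.2%}')        # ~9.62% for these made-up numbers
```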
class SpeedMonitor(Callback):
    """Logs the training throughput.

-     The training throughput in terms of number of samples per second is logged on the
-     :attr:`.Event.BATCH_END` event if we have reached the ``window_size`` threshold.
+     The training throughput is logged on the :attr:`.Event.BATCH_END` event once we have reached
+     the `window_size` threshold. If the model has a `flops_per_batch` attribute, flops per second
+     is also logged. If running on a known GPU type, or if `gpu_flops_available` is set, model flops
+     utilization (MFU) is also logged. Per-device versions of all throughput metrics are logged as
+     well, computed by dividing by the world size.

-     The wall clock train time is logged on every :attr:`.Event.BATCH_END` event.
+     To compute `flops_per_sec`, the model attribute `flops_per_batch` should be set to a callable
+     that accepts a batch and returns the number of flops for that batch. Typically, this is flops
+     per sample times the batch size, unless pad tokens are used.

-     The average throughput over an epoch is logged on the :attr:`.Event.EPOCH_END` event.
+     The wall clock time is logged on every :attr:`.Event.BATCH_END` event.

    Example:
        .. doctest::
@@ -41,84 +161,130 @@ class SpeedMonitor(Callback):
    The training throughput is logged by the :class:`.Logger` to the following keys as
    described below.

-     +----------------------------------+-------------------------------------------------------------+
-     | Key                              | Logged data                                                 |
-     +==================================+=============================================================+
-     |                                  | Rolling average (over ``window_size`` most recent           |
-     | ``throughput/samples_per_sec``   | batches) of the number of samples processed per second      |
-     |                                  |                                                             |
-     +----------------------------------+-------------------------------------------------------------+
-     | ``wall_clock/train``             | Total elapsed training time                                 |
-     +----------------------------------+-------------------------------------------------------------+
-     | ``wall_clock/val``               | Total elapsed validation time                               |
-     +----------------------------------+-------------------------------------------------------------+
-     | ``wall_clock/total``             | Total elapsed time (wall_clock/train + wall_clock/val)      |
-     +----------------------------------+-------------------------------------------------------------+
+     +-------------------------------------+-----------------------------------------------------------+
+     | Key                                 | Logged data                                               |
+     +=====================================+===========================================================+
+     |                                     | Rolling average (over `window_size` most recent           |
+     | `throughput/batches_per_sec`        | batches) of the number of batches processed per second    |
+     |                                     |                                                           |
+     +-------------------------------------+-----------------------------------------------------------+
+     |                                     | Rolling average (over `window_size` most recent           |
+     | `throughput/samples_per_sec`        | batches) of the number of samples processed per second    |
+     |                                     |                                                           |
+     +-------------------------------------+-----------------------------------------------------------+
+     |                                     | Rolling average (over `window_size` most recent           |
+     | `throughput/tokens_per_sec`         | batches) of the number of tokens processed per second.    |
+     |                                     | Only logged when dataloader.dataset has `max_seq_len`.    |
+     |                                     | This may include padding depending on dataset             |
+     +-------------------------------------+-----------------------------------------------------------+
+     |                                     | Estimates flops by `flops_per_batch * batches_per_sec`    |
+     | `throughput/flops_per_sec`          | if model has attribute `flops_per_batch`                  |
+     |                                     |                                                           |
+     +-------------------------------------+-----------------------------------------------------------+
+     | `throughput/device/batches_per_sec` | `throughput/batches_per_sec` divided by world size        |
+     +-------------------------------------+-----------------------------------------------------------+
+     | `throughput/device/samples_per_sec` | `throughput/samples_per_sec` divided by world size        |
+     +-------------------------------------+-----------------------------------------------------------+
+     |                                     | `throughput/tokens_per_sec` divided by world size. Only   |
+     | `throughput/device/tokens_per_sec`  | logged when dataloader.dataset has `max_seq_len`. This    |
+     |                                     | may include pad tokens depending on dataset               |
+     +-------------------------------------+-----------------------------------------------------------+
+     |                                     | `throughput/flops_per_sec` divided by world size. Only    |
+     | `throughput/device/flops_per_sec`   | logged when model has attribute `flops_per_batch`         |
+     |                                     |                                                           |
+     +-------------------------------------+-----------------------------------------------------------+
+     |                                     | `throughput/device/flops_per_sec` divided by              |
+     | `throughput/device/mfu`             | `gpu_flops_available`. Only logged when model has         |
+     |                                     | attribute `flops_per_batch` and `gpu_flops_available`,    |
+     |                                     | which can be passed as an argument if not automatically   |
+     |                                     | determined by SpeedMonitor                                |
+     +-------------------------------------+-----------------------------------------------------------+
+     | `wall_clock/train`                  | Total elapsed training time                               |
+     +-------------------------------------+-----------------------------------------------------------+
+     | `wall_clock/val`                    | Total elapsed validation time                             |
+     +-------------------------------------+-----------------------------------------------------------+
+     | `wall_clock/total`                  | Total elapsed time (wall_clock/train + wall_clock/val)    |
+     +-------------------------------------+-----------------------------------------------------------+

    Args:
        window_size (int, optional): Number of batches to use for a rolling average of throughput.
            Defaults to 100.
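+         gpu_flops_available (int | float, optional): Peak flops per second available on each GPU, used
+             to compute MFU. If not set, SpeedMonitor attempts to determine it automatically from the
+             device name and training precision. Defaults to None.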
    """

-     def __init__(self, window_size: int = 100):
+     def __init__(self, window_size: int = 100, gpu_flops_available: Optional[Union[float, int]] = None):
        # Track the batch num samples and wct to compute throughput over a window of batches
-         self.batch_start_num_samples = 0
-         self.batch_start_wct = 0.0
-         self.batch_wct_buffer: Deque[float] = deque(maxlen=window_size)
-         self.batch_num_samples_buffer: Deque[int] = deque(maxlen=window_size)
-         self.window_size = window_size
+         self.history_samples: Deque[int] = deque(maxlen=window_size + 1)
+         self.history_wct: Deque[float] = deque(maxlen=window_size + 1)
+
+         self.gpu_flops_available = gpu_flops_available

        # Keep track of time spent evaluating
        self.total_eval_wct = 0.0

    def state_dict(self) -> Dict[str, Any]:
        return {
-             'batch_start_num_samples': self.batch_start_num_samples,
-             'batch_start_wct': self.batch_start_wct,
-             'batch_wct_buffer': self.batch_wct_buffer,
-             'batch_num_samples_buffer': self.batch_num_samples_buffer,
-             # "window_wct": self.window_wct,
-             # "window_num_samples": self.window_num_samples,
            'total_eval_wct': self.total_eval_wct,
        }

    def load_state_dict(self, state: Dict[str, Any]) -> None:
-         self.batch_start_num_samples = state['batch_start_num_samples']
-         self.batch_start_wct = state['batch_start_wct']
-         self.batch_wct_buffer = deque(
-             [x for x in state['batch_wct_buffer']],
-             maxlen=self.window_size,
-         )
-         self.batch_num_samples_buffer = deque(
-             [x for x in state['batch_num_samples_buffer']],
-             maxlen=self.window_size,
-         )
        self.total_eval_wct = state['total_eval_wct']

-     def before_dataloader(self, state: State, logger: Logger) -> None:
+     def init(self, state: State, logger: Logger) -> None:
        del logger  # unused
-         self.batch_start_wct = state.timestamp.total_wct.total_seconds()
-         self.batch_start_num_samples = int(state.timestamp.sample)
+         if self.gpu_flops_available is None:
+             self.gpu_flops_available = get_gpu_flops_available(state)

    def batch_end(self, state: State, logger: Logger):
-         batch_num_samples = int(state.timestamp.sample) - self.batch_start_num_samples
-         batch_wct = state.timestamp.total_wct.total_seconds() - self.batch_start_wct
-
        # Add the new element
-         self.batch_wct_buffer.append(batch_wct)
-         self.batch_num_samples_buffer.append(batch_num_samples)
+         self.history_samples.append(state.timestamp.sample.value)
+         self.history_wct.append(state.timestamp.total_wct.total_seconds())

        # Log the throughput
-         if len(self.batch_num_samples_buffer) == self.window_size:
-             throughput = sum(self.batch_num_samples_buffer) / sum(self.batch_wct_buffer)
-             logger.log_metrics({'throughput/samples_per_sec': throughput})
+         if len(self.history_wct) == self.history_wct.maxlen:
+             world_size = dist.get_world_size()
+             elapsed_batches = len(self.history_samples) - 1
+             elapsed_samples = int(self.history_samples[-1]) - int(self.history_samples[0])
+             elapsed_wct = self.history_wct[-1] - self.history_wct[0]
+             batches_per_sec = elapsed_batches / elapsed_wct
+             samples_per_sec = elapsed_samples / elapsed_wct
+             dev_batches_per_sec = batches_per_sec / world_size
+             dev_samples_per_sec = samples_per_sec / world_size
+             logger.log_metrics({'throughput/batches_per_sec': batches_per_sec})
+             logger.log_metrics({'throughput/samples_per_sec': samples_per_sec})
+             logger.log_metrics({'throughput/device/batches_per_sec': dev_batches_per_sec})
+             logger.log_metrics({'throughput/device/samples_per_sec': dev_samples_per_sec})
+
+             # Compute token stats if dataloader.dataset has max_seq_len. Assumes no padding.
+             try:
+                 max_seq_len = state.dataloader.dataset.max_seq_len  # type: ignore
+                 # Only applicable to seq data / models
+                 logger.log_metrics({'throughput/tokens_per_sec': samples_per_sec * max_seq_len})
+                 logger.log_metrics({'throughput/device/tokens_per_sec': dev_samples_per_sec * max_seq_len})
+             except AttributeError:
+                 pass
+
+             composer_model = state.model
+             if not isinstance(composer_model, ComposerModel):
+                 composer_model = composer_model.module  # Pass through DDP wrapping
+             if hasattr(composer_model, 'flops_per_batch'):
+                 model_flops_per_batch = composer_model.flops_per_batch  # type: ignore
+                 if not isinstance(model_flops_per_batch, Callable):
+                     raise TypeError('flops_per_batch must be a callable accepting a batch and '
+                                     f'returning an int or float. Instead, got {type(model_flops_per_batch)}.')
+                 flops_per_batch = model_flops_per_batch(state.batch)
+                 flops_per_sec = flops_per_batch * batches_per_sec
+                 logger.log_metrics({'throughput/flops_per_sec': flops_per_sec})
+                 dev_flops_per_sec = flops_per_sec / world_size
+                 logger.log_metrics({'throughput/device/flops_per_sec': dev_flops_per_sec})
+                 if self.gpu_flops_available:
+                     mfu = dev_flops_per_sec / self.gpu_flops_available
+                     logger.log_metrics({'throughput/device/mfu': mfu})

        # Log the time
        # `state.timestamp` excludes any time spent in evaluation
        logger.log_metrics({
            'wall_clock/train': state.timestamp.total_wct.total_seconds(),
            'wall_clock/val': self.total_eval_wct,
-             'wall_clock/total': (state.timestamp.total_wct.total_seconds() + self.total_eval_wct),
+             'wall_clock/total': state.timestamp.total_wct.total_seconds() + self.total_eval_wct,
        })
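The deques above hold `window_size + 1` cumulative readings because throughput over N batches needs N + 1 samples of the cumulative counters. A standalone sketch of the same arithmetic with made-up numbers:

```python
from collections import deque

window_size = 3
history_samples = deque(maxlen=window_size + 1)  # cumulative sample counts at each batch_end
history_wct = deque(maxlen=window_size + 1)      # cumulative wall-clock seconds at each batch_end

# Pretend four batch_end events arrive, 32 samples and 0.5 s apart.
for samples, wct in [(32, 0.5), (64, 1.0), (96, 1.5), (128, 2.0)]:
    history_samples.append(samples)
    history_wct.append(wct)

elapsed_batches = len(history_samples) - 1                  # 3 batches inside the window
elapsed_samples = history_samples[-1] - history_samples[0]  # 128 - 32 = 96 samples
elapsed_wct = history_wct[-1] - history_wct[0]              # 2.0 - 0.5 = 1.5 seconds
print(elapsed_batches / elapsed_wct)                        # 2.0 batches/sec
print(elapsed_samples / elapsed_wct)                        # 64.0 samples/sec
```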

    def eval_end(self, state: State, logger: Logger):