Skip to content

Commit 6f68c16

Browse files
authored
Add HealthChecker Callback (#2002)
Adds a GPUHealth checker callback that alerts for anomalous GPU metrics
1 parent e07de7e commit 6f68c16

File tree

4 files changed

+309
-0
lines changed

4 files changed

+309
-0
lines changed

composer/callbacks/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from composer.callbacks.checkpoint_saver import CheckpointSaver
1010
from composer.callbacks.early_stopper import EarlyStopper
1111
from composer.callbacks.export_for_inference import ExportForInferenceCallback
12+
from composer.callbacks.health_checker import HealthChecker
1213
from composer.callbacks.image_visualizer import ImageVisualizer
1314
from composer.callbacks.lr_monitor import LRMonitor
1415
from composer.callbacks.memory_monitor import MemoryMonitor
@@ -29,5 +30,6 @@
2930
'ExportForInferenceCallback',
3031
'ThresholdStopper',
3132
'ImageVisualizer',
33+
'HealthChecker',
3234
'RuntimeEstimator',
3335
]

composer/callbacks/health_checker.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# Copyright 2022 MosaicML Composer authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Check GPU Health during training."""
5+
import logging
import os
from collections import deque
from datetime import datetime
from typing import List, Optional, Tuple

import numpy as np
import torch

# pynvml and slack_sdk are optional (they ship with the ``health_checker`` extra);
# fall back to ``None`` so that importing this module never fails without them.
try:
    import pynvml
except ImportError:
    pynvml = None

try:
    from slack_sdk.webhook import WebhookClient
except ImportError:
    WebhookClient = None  # type: ignore

from composer.core import Callback, State
from composer.core.time import Timestamp
from composer.loggers import Logger
from composer.utils import dist
26+
27+
log = logging.getLogger(__name__)
28+
29+
__all__ = ['HealthChecker']
30+
31+
32+
class HealthChecker(Callback):
    """Checks for GPU health.

    This callback checks for GPU health by tracking and alerting for abnormal
    GPU utilizations.

    For example, if the average utilization during the observation window is
    [30, 30, 45], then the range (45-30=15) would exceed a threshold of 10%.

    Args:
        threshold (float, optional): Threshold of GPU utilization range to
            trigger an alert. Defaults to 10.
        sample_freq (int, optional): Sample frequency in seconds. Default: 5.
        window_size (int, optional): Window size in seconds. HealthChecker will
            check for abnormalities at this frequency. Default: 120.
        wait (int, optional): Seconds to wait for starting to sample. Default: 120.
        slack_webhook_url (str, optional): Slack URL to send alerts. Can also
            be set with the SLACK_WEBHOOK_URL environment variable. Default: None
        test_mode (bool, optional): If True, will send a test alert at the first check.
            Default: False
    """

    def __init__(
        self,
        threshold: float = 10,
        sample_freq: int = 5,
        window_size: int = 120,
        wait: int = 120,
        slack_webhook_url: Optional[str] = None,
        test_mode: bool = False,
    ) -> None:
        self.sample_freq = sample_freq
        self.window_size = window_size
        self.wait = wait
        self.slack_webhook_url = slack_webhook_url
        self.test_mode = test_mode

        # An explicit argument wins; otherwise fall back to the environment.
        if not self.slack_webhook_url:
            self.slack_webhook_url = os.environ.get('SLACK_WEBHOOK_URL', None)

        # Wall-clock second (relative to training start) of the last sample / check.
        self.last_sample = 0
        self.last_check = 0

        # Stays empty when GPU monitoring is unavailable, turning the callback into a no-op.
        self.metrics = []
        if self._is_available():
            self.metrics.append(GPUUtilization(threshold))

    def init(self, state: State, logger: Logger) -> None:
        pass

    def after_train_batch(self, state: State, logger: Logger):
        """Sample the metrics and, once per window, check them and alert if needed."""
        if not self.metrics:
            return

        if self._sample(state.timestamp):
            for metric in self.metrics:
                metric.sample()

        if self._check(state.timestamp):
            for metric in self.metrics:
                message, alert = metric.check()
                if self.test_mode and message:
                    alert = True
                    message = '[**THIS IS A TEST**]' + message
                # Each metric alerts at most once for the lifetime of the callback.
                if alert and not metric.alerted:
                    self._alert(message, state)
                    metric.alerted = True
                # Start a fresh observation window.
                metric.clear()

    @staticmethod
    def _elapsed_seconds(timestamp: Timestamp) -> int:
        """Return the whole seconds of training wall-clock time in ``timestamp``.

        ``timedelta.seconds`` only holds the sub-day component, so it wraps back
        to 0 every 24 hours; ``total_seconds()`` also accounts for whole days.
        """
        return int(timestamp.total_wct.total_seconds())

    def _sample(self, timestamp: Timestamp) -> bool:
        """Return True (and record the time) if a new sample is due."""
        now = self._elapsed_seconds(timestamp)

        if now < self.wait:
            return False

        if now - self.last_sample >= self.sample_freq:
            self.last_sample = now
            return True

        return False

    def _check(self, timestamp: Timestamp) -> bool:
        """Return True (and record the time) if a full window has elapsed since the last check."""
        now = self._elapsed_seconds(timestamp)

        if now - self.last_check >= self.window_size:
            self.last_check = now
            return True
        return False

    def _alert(self, message: str, state: State) -> None:
        """Log ``message`` as a warning and, if a webhook URL is configured, post it to Slack."""
        prefix = '[{now}][{run_name}][node_rank={node_rank}]'.format(
            now=datetime.now(),
            run_name=state.run_name,
            node_rank=dist.get_node_rank(),
        )

        node_name = os.environ.get('NODENAME', None)
        if node_name is not None:
            prefix += f'[node={node_name}]'

        message = prefix + ' : ' + message

        log.warning(message)
        if not self.slack_webhook_url:
            return

        if WebhookClient is None:
            log.warning('slack_sdk library not found, skipping Slack alert.')
            return

        # Alerting is best effort: a Slack/network failure must not take down training.
        try:
            client = WebhookClient(url=self.slack_webhook_url)
            client.send(text=message)
        except Exception as e:
            log.warning(f'Error sending Slack alert: {e}')

    @staticmethod
    def _is_available() -> bool:
        """Return True if CUDA is available and NVML could be initialized."""
        if not torch.cuda.is_available():
            return False

        # The module-level import falls back to None when pynvml is missing; check it
        # up front, because touching ``pynvml.<attr>`` on None raises AttributeError.
        if pynvml is None:
            log.warning('pynvml library not found, disabling GPU health checking.')
            return False

        try:
            pynvml.nvmlInit()  # type: ignore
            return True
        except pynvml.NVMLError_LibraryNotFound:  # type: ignore
            log.warning('NVML not found, disabling GPU health checking')
        except Exception as e:
            log.warning(f'Error initializing NVML: {e}')

        return False
154+
155+
156+
class GPUUtilization:
    """GPU Utilization Metric.

    Collects per-device GPU utilization readings through NVML and flags the
    window as abnormal when the spread between the busiest and the idlest
    device (averaged over the window) exceeds ``threshold``.

    Args:
        threshold (float, optional): Maximum tolerated difference between the
            highest and lowest average utilization. Default: 10.
    """

    def __init__(self, threshold: float = 10) -> None:
        # One entry per sample; each entry is a list with one reading per device.
        self.samples = deque()
        self.threshold = threshold
        # Set by the owning callback once an alert has been sent for this metric.
        self.alerted = False

    def sample(self) -> None:
        """Record one utilization reading; only local rank 0 samples."""
        if dist.get_local_rank() == 0:
            sample = self._sample()
            if sample is not None:
                self.samples.append(sample)

    def _sample(self) -> Optional[List]:
        """Return the current utilization of every device, or None if NVML is unusable."""
        # pynvml is None when the optional dependency is not installed.
        if pynvml is None:
            return None
        try:
            samples = []
            device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
            for i in range(device_count):
                handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
                samples.append(pynvml.nvmlDeviceGetUtilizationRates(handle).gpu)  # type: ignore
        except pynvml.NVMLError:  # type: ignore
            return None
        return samples

    def check(self) -> Tuple[Optional[str], bool]:
        """Evaluate the current window.

        Returns:
            A ``(message, alert)`` tuple. ``message`` is None on ranks other than
            local rank 0 and when no samples were collected in the window.
        """
        if dist.get_local_rank() != 0:
            return None, False
        return self._evaluate()

    def _evaluate(self) -> Tuple[Optional[str], bool]:
        """Compare the per-device average utilization spread against the threshold."""
        # Nothing was sampled (e.g. still waiting, or NVML kept failing): averaging
        # an empty window would only produce NaNs and numpy RuntimeWarnings.
        if not self.samples:
            return None, False

        # axis=0 averages across samples, leaving one value per device.
        average_sample = np.nanmean(list(self.samples), axis=0)
        if np.nanmax(average_sample) - np.nanmin(average_sample) > self.threshold:
            message = f'Abnormal GPU utilizations: {average_sample}'
            return message, True

        message = f':+1: Normal GPU utilizations: {average_sample}'
        return message, False

    def clear(self) -> None:
        """Drop all collected samples to start a new window."""
        self.samples.clear()

setup.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,11 @@ def package_files(prefix: str, directory: str, extension: str):
136136
'setuptools<=59.5.0',
137137
]
138138

139+
# Optional dependencies of the HealthChecker callback (NVML bindings and Slack webhooks).
# Kept as a list, like the other extras, so the requirement order is deterministic.
extra_deps['health_checker'] = [
    'pynvml>=11.5.0,<12',
    'slack_sdk>=3.19.5,<4',
]
143+
139144
extra_deps['deepspeed'] = [
140145
'deepspeed==0.7.7',
141146
]
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright 2022 MosaicML Composer authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import datetime
5+
from unittest.mock import MagicMock, patch
6+
7+
import pytest
8+
9+
from composer import Timestamp
10+
from composer.callbacks import HealthChecker
11+
from composer.callbacks.health_checker import GPUUtilization
12+
from composer.utils import dist
13+
from tests.common import world_size
14+
15+
pynvml = pytest.importorskip('pynvml')
16+
pytest.importorskip('slack_sdk')
17+
18+
19+
class MockUtil:
    """Stand-in for the object returned by ``pynvml.nvmlDeviceGetUtilizationRates``."""

    def __init__(self, util) -> None:
        # ``gpu`` is the only attribute the code under test reads from the result.
        self.gpu = util
23+
24+
25+
@pytest.mark.gpu
@world_size(1, 2)
def test_gpu_utilization(world_size):
    """GPUUtilization alerts only on local rank 0 when devices diverge."""
    assert HealthChecker._is_available()

    # Readings are consumed in order, one per device per sample; only the
    # second reading is an outlier.
    readings = [MockUtil(value) for value in (100, 10, 100, 100, 100, 100)]

    nvml_patches = {
        'nvmlDeviceGetUtilizationRates': MagicMock(side_effect=readings),
        'nvmlDeviceGetCount': MagicMock(return_value=world_size),
    }
    with patch.multiple(pynvml, **nvml_patches):
        metric = GPUUtilization()
        for _ in range(3):
            metric.sample()
        _, alert = metric.check()

        # With a single device there is no spread to compare, so no alert.
        expected_alert = dist.get_local_rank() == 0 and world_size > 1
        assert alert == expected_alert
51+
52+
53+
@pytest.mark.gpu
@world_size(1, 2)
def test_health_checker(world_size):
    """Drive the callback through one window and verify the alert flag."""
    logger = MagicMock()
    state = MagicMock()
    state.run_name = 'pytest-mock-run-kwei73'

    # No warm-up, sample every second, check after three seconds.
    checker = HealthChecker(sample_freq=1, window_size=3, wait=0)

    # Readings are consumed in order, one per device per sample; only the
    # second reading is an outlier.
    readings = [MockUtil(value) for value in (100, 10, 100, 100, 100, 100)]

    nvml_patches = {
        'nvmlDeviceGetUtilizationRates': MagicMock(side_effect=readings),
        'nvmlDeviceGetCount': MagicMock(return_value=world_size),
    }
    with patch.multiple(pynvml, **nvml_patches):
        # collect data and checker
        for elapsed in (1, 2, 3):
            state.timestamp = Timestamp(total_wct=datetime.timedelta(seconds=elapsed))
            checker.after_train_batch(state, logger)

        expected_alert = dist.get_local_rank() == 0 and world_size > 1
        assert checker.metrics[0].alerted == expected_alert
87+
88+
89+
def test_health_checker_sampling():
    """``_sample`` honours the warm-up period and the sampling frequency."""
    health_checker = HealthChecker(
        sample_freq=1,
        window_size=5,
        wait=10,
    )

    # (elapsed seconds, expected result). The cases are order dependent because
    # every successful sample updates ``last_sample`` on the checker.
    config = [
        (5, False),  # before wait
        (11, True),
        (11.5, False),  # below sample frequency
        (12, True),
        (20, True),
        (11, False),  # no time travel
    ]

    for seconds, is_sample in config:
        timestamp = Timestamp(total_wct=datetime.timedelta(seconds=seconds))
        assert health_checker._sample(timestamp) == is_sample

0 commit comments

Comments
 (0)