Commit 0926cbf

reset scaler (#1999)
1 parent 850e34d commit 0926cbf

File tree

1 file changed: +6 -1 lines changed
composer/trainer/trainer.py

Lines changed: 6 additions & 1 deletion
@@ -15,6 +15,7 @@
 import re
 import time
 import warnings
+from collections import defaultdict
 from copy import deepcopy
 from pathlib import Path
 from typing import Any, Callable, ContextManager, Dict, Iterable, List, Optional, Sequence, TextIO, Tuple, Union, cast
@@ -24,7 +25,7 @@
 import torch.distributed
 import torch.nn as nn
 import torch.utils.data
-from torch.cuda.amp.grad_scaler import GradScaler
+from torch.cuda.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader, DistributedSampler
 from torchmetrics import Metric
@@ -257,6 +258,8 @@ def _adjust_grad_accum(state: State, device_batch_size: int):
     del state.loss
     for optimizer in state.optimizers:
         optimizer.zero_grad(set_to_none=True)
+    if state.scaler is not None:
+        state.scaler._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
     torch.cuda.empty_cache()


@@ -285,6 +288,8 @@ def _adjust_device_train_microbatch_size(state: State):
     del state.loss
     for optimizer in state.optimizers:
         optimizer.zero_grad(set_to_none=True)
+    if state.scaler is not None:
+        state.scaler._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
     torch.cuda.empty_cache()
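The reset duplicates what GradScaler.update() does at the end of a successful step: PyTorch rebuilds _per_optimizer_states as defaultdict(_refresh_per_optimizer_state), returning every optimizer to OptState.READY. When the trainer catches a CUDA OOM mid-batch and retries with more gradient accumulation (or a smaller device microbatch), the abandoned step never reaches scaler.update(), so without this change the scaler would carry stale per-optimizer state into the retry. Below is a minimal sketch of that failure mode; the toy model, optimizer, and simulated OOM are illustrative assumptions, not code from the commit, and it needs a CUDA device to run.

from collections import defaultdict

import torch
from torch.cuda.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state

# Toy setup, purely illustrative.
model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()

try:
    loss = model(torch.randn(4, 8, device='cuda')).sum()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # scaler marks this optimizer as UNSCALED
    raise RuntimeError('CUDA out of memory')  # simulate an OOM before step()
    # scaler.step(optimizer) and scaler.update() would normally follow;
    # update() is what resets the per-optimizer state.
except RuntimeError:
    # The recovery path, mirroring _adjust_grad_accum above.
    optimizer.zero_grad(set_to_none=True)
    # Without this line, the retried batch fails inside scaler.unscale_() with
    # "unscale_() has already been called on this optimizer since the last
    # update()." Rebuilding the dict restores OptState.READY for every
    # optimizer, exactly as scaler.update() would have.
    scaler._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
    torch.cuda.empty_cache()

# The retried batch now goes through a clean scale/backward/unscale_/step.
loss = model(torch.randn(4, 8, device='cuda')).sum()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
scaler.step(optimizer)
scaler.update()

Note that both _per_optimizer_states and _refresh_per_optimizer_state are private torch.cuda.amp.grad_scaler internals, so this reset is deliberately coupled to the GradScaler implementation as it existed when the commit landed.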