Commit 31f4651 (1 parent: 9921323)

Add AMP to ImageNet classification script

Signed-off-by: Serge Panev <[email protected]>

1 file changed: scripts/classification/imagenet/train_imagenet.py (+17 lines, -3 lines)
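The change wires MXNet's automatic mixed precision (AMP) support into the training script in four steps: import the contrib AMP module, expose an opt-in --amp flag, initialize AMP before the network and trainer are built, and scale the loss before backward so small float16 gradients do not underflow. A condensed, self-contained sketch of the same pattern follows the diff.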
@@ -7,6 +7,7 @@
 from mxnet import autograd as ag
 from mxnet.gluon import nn
 from mxnet.gluon.data.vision import transforms
+from mxnet.contrib import amp

 import gluoncv as gcv
 gcv.utils.check_version('0.6.0')
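Note: mxnet.contrib.amp is the experimental AMP package that ships with MXNet 1.5 and later; once initialized, it automatically casts operators that are safe in float16 while keeping precision-sensitive ones in float32.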
@@ -105,6 +106,8 @@ def parse_args():
                         help='name of training log file')
     parser.add_argument('--use-gn', action='store_true',
                         help='whether to use group norm.')
+    parser.add_argument('--amp', action='store_true',
+                        help='Use MXNet AMP for mixed precision training.')
     opt = parser.parse_args()
     return opt

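Because the new flag uses action='store_true', it defaults to False: AMP is strictly opt-in, and existing float32 training runs are unaffected.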
@@ -122,6 +125,9 @@ def main():

     logger.info(opt)

+    if opt.amp:
+        amp.init()
+
     batch_size = opt.batch_size
     classes = 1000
     num_training_samples = 1281167
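amp.init() is called near the top of main(), before the model is constructed, in line with the MXNet AMP tutorial's recommendation to initialize AMP as early as possible so that everything created afterwards is patched for mixed precision.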
@@ -349,10 +355,13 @@ def train(ctx):
         for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
             v.wd_mult = 0.0

-    trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
+    trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params, update_on_kvstore=(False if opt.amp else None))
     if opt.resume_states is not '':
         trainer.load_states(opt.resume_states)

+    if opt.amp:
+        amp.init_trainer(trainer)
+
     if opt.label_smoothing or opt.mixup:
         sparse_label_loss = False
     else:
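Two trainer-side changes go together here: update_on_kvstore is forced to False under AMP because amp.init_trainer() requires the Trainer to perform parameter updates itself rather than on the KVStore, and init_trainer() then attaches a dynamic loss scaler that adjusts the scale factor and skips the optimizer step when a gradient overflow is detected. Passing None when AMP is off preserves the Trainer's default behavior.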
@@ -404,8 +413,13 @@ def train(ctx):
                               p.astype('float32', copy=False)) for yhat, y, p in zip(outputs, label, teacher_prob)]
                 else:
                     loss = [L(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(outputs, label)]
-                for l in loss:
-                    l.backward()
+                if opt.amp:
+                    with amp.scale_loss(loss, trainer) as scaled_loss:
+                        ag.backward(scaled_loss)
+                else:
+                    for l in loss:
+                        l.backward()
+
             trainer.step(batch_size)

             if opt.mixup:
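Taken together, the hunks follow the standard MXNet AMP recipe: init, patch the trainer, scale the loss, backward, step. The sketch below condenses that recipe into a minimal Gluon loop. It assumes a CUDA-capable GPU build of MXNet 1.5+; the toy Dense network, random batch, and SGD hyperparameters are placeholders invented for illustration, while the amp.* calls and Trainer arguments mirror the diff above.

import mxnet as mx
from mxnet import autograd as ag, gluon
from mxnet.contrib import amp

# 1) Patch MXNet for mixed precision BEFORE any part of the model exists.
amp.init()

ctx = mx.gpu(0)                    # AMP's float16 kernels need a CUDA GPU
net = gluon.nn.Dense(10)           # toy stand-in for the ImageNet network
net.initialize(ctx=ctx)

# 2) Under AMP the Trainer must own the update step (not the KVStore).
trainer = gluon.Trainer(net.collect_params(), 'sgd',
                        {'learning_rate': 0.1},
                        update_on_kvstore=False)

# 3) Attach AMP's dynamic loss scaler to the trainer.
amp.init_trainer(trainer)

loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
X = mx.nd.random.uniform(shape=(8, 32), ctx=ctx)                      # fake batch
y = mx.nd.random.randint(0, 10, shape=(8,), ctx=ctx).astype('float32')  # fake labels

with ag.record():
    out = net(X)
    loss = loss_fn(out, y)
    # 4) Scale the loss so small float16 gradients do not flush to zero;
    #    AMP picks and adjusts the scale factor dynamically.
    with amp.scale_loss(loss, trainer) as scaled_loss:
        ag.backward(scaled_loss)

trainer.step(X.shape[0])   # the update is skipped automatically on overflow

In the actual script the same calls are guarded by opt.amp, so a single code path serves both plain float32 and mixed-precision training.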
