Commit 52ead50

Add AMP to ImageNet classification script

Signed-off-by: Serge Panev <[email protected]>
1 parent 933f2ea commit 52ead50

1 file changed: 17 additions (+), 3 deletions (-)

scripts/classification/imagenet/train_imagenet.py  (+17, -3)

@@ -6,6 +6,7 @@
 from mxnet import gluon, nd
 from mxnet import autograd as ag
 from mxnet.gluon.data.vision import transforms
+from mxnet.contrib import amp
 
 import gluoncv as gcv
 gcv.utils.check_version('0.6.0')
@@ -104,6 +105,8 @@ def parse_args():
                         help='name of training log file')
     parser.add_argument('--use-gn', action='store_true',
                         help='whether to use group norm.')
+    parser.add_argument('--amp', action='store_true',
+                        help='Use MXNet AMP for mixed precision training.')
     opt = parser.parse_args()
     return opt
 
@@ -121,6 +124,9 @@ def main():
 
     logger.info(opt)
 
+    if opt.amp:
+        amp.init()
+
     batch_size = opt.batch_size
     classes = 1000
     num_training_samples = 1281167
@@ -347,10 +353,13 @@ def train(ctx):
             for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                 v.wd_mult = 0.0
 
-        trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
+        trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params, update_on_kvstore=(False if opt.amp else None))
         if opt.resume_states != '':
             trainer.load_states(opt.resume_states)
 
+        if opt.amp:
+            amp.init_trainer(trainer)
+
         if opt.label_smoothing or opt.mixup:
             sparse_label_loss = False
         else:
@@ -402,8 +411,13 @@ def train(ctx):
                                   p.astype('float32', copy=False)) for yhat, y, p in zip(outputs, label, teacher_prob)]
                 else:
                     loss = [L(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(outputs, label)]
-                for l in loss:
-                    l.backward()
+                if opt.amp:
+                    with amp.scale_loss(loss, trainer) as scaled_loss:
+                        ag.backward(scaled_loss)
+                else:
+                    for l in loss:
+                        l.backward()
+
                 trainer.step(batch_size)
 
                 if opt.mixup:
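
For reference, the AMP pattern this commit wires into the script can be exercised end to end with a small standalone program. This is a minimal sketch, assuming MXNet >= 1.5 with mxnet.contrib.amp available; the toy Dense network, the SGD hyperparameters, and the random batch are illustrative stand-ins for the real ImageNet model and data loader, not part of the commit.

# Minimal sketch of the AMP workflow applied in this commit (illustrative only).
import mxnet as mx
from mxnet import gluon, autograd as ag
from mxnet.contrib import amp

# Patch the operator lists for mixed precision before any network is built,
# mirroring the amp.init() call added to main().
amp.init()

# AMP targets GPU training; fall back to CPU here only so the sketch runs.
ctx = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu()

net = gluon.nn.Dense(10)  # stand-in for the real ImageNet model
net.initialize(ctx=ctx)

# AMP needs parameter updates to happen outside the kvstore so it can rescale
# gradients, hence update_on_kvstore=False (as in the patched Trainer above).
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1},
                        update_on_kvstore=False)
amp.init_trainer(trainer)  # attach dynamic loss scaling to the trainer

loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
data = mx.nd.random.uniform(shape=(4, 8), ctx=ctx)
label = mx.nd.array([0, 1, 2, 3], ctx=ctx)

with ag.record():
    loss = loss_fn(net(data), label)
    # Scale the loss so small float16 gradients do not underflow,
    # then backpropagate through the scaled value.
    with amp.scale_loss(loss, trainer) as scaled_loss:
        ag.backward(scaled_loss)
trainer.step(data.shape[0])

With the change in place, mixed precision training is enabled by passing the new flag to the script, e.g. python scripts/classification/imagenet/train_imagenet.py --amp together with the script's existing arguments.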
