 from mxnet import autograd as ag
 from mxnet.gluon import nn
 from mxnet.gluon.data.vision import transforms
+from mxnet.contrib import amp
 
 import gluoncv as gcv
 gcv.utils.check_version('0.6.0')
@@ -105,6 +106,8 @@ def parse_args():
                         help='name of training log file')
     parser.add_argument('--use-gn', action='store_true',
                         help='whether to use group norm.')
+    parser.add_argument('--amp', action='store_true',
+                        help='Use MXNet AMP for mixed precision training.')
     opt = parser.parse_args()
     return opt
 
@@ -122,6 +125,9 @@ def main():
 
     logger.info(opt)
 
+    if opt.amp:
+        amp.init()
+
     batch_size = opt.batch_size
     classes = 1000
     num_training_samples = 1281167
@@ -349,10 +355,13 @@ def train(ctx):
         for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
             v.wd_mult = 0.0
 
-    trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
+    trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params, update_on_kvstore=(False if opt.amp else None))
     if opt.resume_states is not '':
         trainer.load_states(opt.resume_states)
 
+    if opt.amp:
+        amp.init_trainer(trainer)
+
     if opt.label_smoothing or opt.mixup:
         sparse_label_loss = False
     else:
@@ -404,8 +413,13 @@ def train(ctx):
                                     p.astype('float32', copy=False)) for yhat, y, p in zip(outputs, label, teacher_prob)]
                 else:
                     loss = [L(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(outputs, label)]
-                for l in loss:
-                    l.backward()
+                if opt.amp:
+                    with amp.scale_loss(loss, trainer) as scaled_loss:
+                        ag.backward(scaled_loss)
+                else:
+                    for l in loss:
+                        l.backward()
+
             trainer.step(batch_size)
 
             if opt.mixup:
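
For reference, below is a minimal standalone sketch of the AMP workflow this patch wires into train_imagenet.py: amp.init() before the net is built, update_on_kvstore=False on the Trainer, amp.init_trainer(), and amp.scale_loss() around backward. The toy Dense net, random data, and hyperparameters are illustrative placeholders and not part of the patch; AMP is intended for GPU contexts, so the CPU fallback is only there to keep the sketch self-contained.

# Minimal sketch of the MXNet AMP workflow used by this patch.
# Only the amp.* calls and the Trainer change mirror the diff; the rest is a toy setup.
import mxnet as mx
from mxnet import autograd as ag
from mxnet import gluon
from mxnet.gluon import nn
from mxnet.contrib import amp

amp.init()  # patch eligible float32 ops to float16; call before building the net

# AMP targets GPUs; the CPU fallback here just keeps the sketch runnable as a shape check.
ctx = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu()
net = nn.Dense(10)
net.initialize(ctx=ctx)

# AMP needs the optimizer update to run outside the kvstore, hence
# update_on_kvstore=False (the patch passes False only when --amp is set).
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1},
                        update_on_kvstore=False)
amp.init_trainer(trainer)  # attach dynamic loss scaling to this trainer

loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
data = mx.nd.random.uniform(shape=(8, 32), ctx=ctx)       # placeholder batch
label = mx.nd.random.randint(0, 10, shape=(8,), ctx=ctx).astype('float32')

with ag.record():
    out = net(data)
    loss = loss_fn(out, label)
    # Scale the loss so float16 gradients do not underflow; gradients are
    # unscaled again before trainer.step() applies the update.
    with amp.scale_loss(loss, trainer) as scaled_loss:
        ag.backward(scaled_loss)
trainer.step(data.shape[0])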