@@ -6,6 +6,7 @@
 from mxnet import gluon, nd
 from mxnet import autograd as ag
 from mxnet.gluon.data.vision import transforms
+from mxnet.contrib import amp
 
 import gluoncv as gcv
 gcv.utils.check_version('0.6.0')
@@ -104,6 +105,8 @@ def parse_args():
                         help='name of training log file')
     parser.add_argument('--use-gn', action='store_true',
                         help='whether to use group norm.')
+    parser.add_argument('--amp', action='store_true',
+                        help='Use MXNet AMP for mixed precision training.')
     opt = parser.parse_args()
     return opt
 
@@ -121,6 +124,9 @@ def main():
 
     logger.info(opt)
 
+    if opt.amp:
+        amp.init()
+
     batch_size = opt.batch_size
     classes = 1000
     num_training_samples = 1281167
@@ -347,10 +353,13 @@ def train(ctx):
     for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
         v.wd_mult = 0.0
 
-    trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
+    trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params, update_on_kvstore=(False if opt.amp else None))
     if opt.resume_states != '':
         trainer.load_states(opt.resume_states)
 
+    if opt.amp:
+        amp.init_trainer(trainer)
+
     if opt.label_smoothing or opt.mixup:
         sparse_label_loss = False
     else:
@@ -402,8 +411,13 @@ def train(ctx):
                               p.astype('float32', copy=False)) for yhat, y, p in zip(outputs, label, teacher_prob)]
                 else:
                     loss = [L(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(outputs, label)]
-                for l in loss:
-                    l.backward()
+                if opt.amp:
+                    with amp.scale_loss(loss, trainer) as scaled_loss:
+                        ag.backward(scaled_loss)
+                else:
+                    for l in loss:
+                        l.backward()
+
             trainer.step(batch_size)
 
             if opt.mixup:
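For readers unfamiliar with MXNet AMP, the patch follows the standard three-step `mxnet.contrib.amp` workflow: `amp.init()` before the network is created, a trainer built with `update_on_kvstore=False` and registered via `amp.init_trainer()`, and a backward pass routed through the `amp.scale_loss()` context manager. Below is a minimal standalone sketch of that workflow; the toy `Dense` model, random data, and SGD hyperparameters are illustrative placeholders, not part of the patched script.

import mxnet as mx
from mxnet import autograd, gluon
from mxnet.contrib import amp

# AMP must be initialized before the network is constructed so that the
# framework's operators are patched for float16-safe casting.
amp.init()

ctx = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu()

# Toy stand-in for the ImageNet network (hypothetical, for illustration only).
net = gluon.nn.Dense(10)
net.initialize(ctx=ctx)

# AMP needs to own the parameter update so it can apply (and, on overflow,
# skip) the dynamically scaled gradients; hence update_on_kvstore=False.
trainer = gluon.Trainer(net.collect_params(), 'sgd',
                        {'learning_rate': 0.1},
                        update_on_kvstore=False)

# Attach AMP's dynamic loss scaler to this trainer.
amp.init_trainer(trainer)

loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
data = mx.nd.random.uniform(shape=(32, 100), ctx=ctx)
label = mx.nd.random.randint(0, 10, shape=(32,), ctx=ctx).astype('float32')

with autograd.record():
    out = net(data)
    loss = loss_fn(out, label)
    # Scale the loss inside the record scope: backward runs on the scaled
    # value, and AMP unscales the gradients before trainer.step().
    with amp.scale_loss(loss, trainer) as scaled_loss:
        autograd.backward(scaled_loss)

trainer.step(data.shape[0])

The `update_on_kvstore=(False if opt.amp else None)` in the patch matters because AMP's dynamic loss scaling has to inspect each update and skip it on gradient overflow, which it cannot do when the update happens inside the kvstore; passing `None` keeps Gluon's default behavior when AMP is off.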