8
8
import mxnet as mx
9
9
from mxnet import gluon , autograd
10
10
from mxnet .gluon .data .vision import transforms
11
+ from mxnet .contrib import amp
11
12
12
13
import gluoncv
13
14
gluoncv .utils .check_version ('0.6.0' )
@@ -95,7 +96,9 @@ def parse_args():
95
96
help = 'using Synchronized Cross-GPU BatchNorm' )
96
97
# the parser
97
98
args = parser .parse_args ()
98
-
99
+ # performance related
100
+ parser .add_argument ('--amp' , action = 'store_true' ,
101
+ help = 'Use MXNet AMP for mixed precision training.' )
99
102
# handle contexts
100
103
if args .no_cuda :
101
104
print ('Using CPU' )
@@ -200,7 +203,11 @@ def __init__(self, args, logger):
200
203
v .wd_mult = 0.0
201
204
202
205
self .optimizer = gluon .Trainer (self .net .module .collect_params (), 'sgd' ,
203
- optimizer_params , kvstore = kv )
206
+ optimizer_params , kvstore = (False if args .amp else None ))
207
+
208
+
209
+ if args .amp :
210
+ amp .init_trainer (trainer )
204
211
# evaluation metrics
205
212
self .metric = gluoncv .utils .metrics .SegmentationMetric (trainset .num_class )
206
213
@@ -212,7 +219,11 @@ def training(self, epoch):
212
219
outputs = self .net (data .astype (args .dtype , copy = False ))
213
220
losses = self .criterion (outputs , target )
214
221
mx .nd .waitall ()
215
- autograd .backward (losses )
222
+ if args .amp :
223
+ with amp .scale_loss (losses , self .optimizer ) as scaled_losses :
224
+ autograd .backward (scaled_losses )
225
+ else :
226
+ autograd .backward (losses )
216
227
self .optimizer .step (self .args .batch_size )
217
228
for loss in losses :
218
229
train_loss += np .mean (loss .asnumpy ()) / len (losses )
@@ -252,7 +263,8 @@ def save_checkpoint(net, args, epoch, mIoU, is_best=False):
252
263
253
264
if __name__ == "__main__" :
254
265
args = parse_args ()
255
-
266
+ if args .amp :
267
+ amp .init ()
256
268
# build logger
257
269
filehandler = logging .FileHandler (os .path .join (args .save_dir , args .logging_file ))
258
270
streamhandler = logging .StreamHandler ()
0 commit comments