Skip to content

Commit c2390f6

Browse files
committed
Add AMP to ImageNet segmentation script
Signed-off-by: Serge Panev <[email protected]>
1 parent 31f4651 commit c2390f6

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

scripts/segmentation/train.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
 import mxnet as mx
 from mxnet import gluon, autograd
 from mxnet.gluon.data.vision import transforms
+from mxnet.contrib import amp

 import gluoncv
 gluoncv.utils.check_version('0.6.0')
@@ -95,7 +96,9 @@ def parse_args():
 help='using Synchronized Cross-GPU BatchNorm')
 # the parser
 args = parser.parse_args()
-
+# performance related
+parser.add_argument('--amp', action='store_true',
+help='Use MXNet AMP for mixed precision training.')
 # handle contexts
 if args.no_cuda:
 print('Using CPU')
@@ -200,7 +203,11 @@ def __init__(self, args, logger):
 v.wd_mult = 0.0

 self.optimizer = gluon.Trainer(self.net.module.collect_params(), 'sgd',
-optimizer_params, kvstore=kv)
+optimizer_params, kvstore=(False if args.amp else None))
+
+
+if args.amp:
+amp.init_trainer(trainer)
 # evaluation metrics
 self.metric = gluoncv.utils.metrics.SegmentationMetric(trainset.num_class)

@@ -212,7 +219,11 @@ def training(self, epoch):
 outputs = self.net(data.astype(args.dtype, copy=False))
 losses = self.criterion(outputs, target)
 mx.nd.waitall()
-autograd.backward(losses)
+if args.amp:
+with amp.scale_loss(losses, self.optimizer) as scaled_losses:
+autograd.backward(scaled_losses)
+else:
+autograd.backward(losses)
 self.optimizer.step(self.args.batch_size)
 for loss in losses:
 train_loss += np.mean(loss.asnumpy()) / len(losses)
@@ -252,7 +263,8 @@ def save_checkpoint(net, args, epoch, mIoU, is_best=False):

 if __name__ == "__main__":
 args = parse_args()
-
+if args.amp:
+amp.init()
 # build logger
 filehandler = logging.FileHandler(os.path.join(args.save_dir, args.logging_file))
 streamhandler = logging.StreamHandler()

0 commit comments

Comments
 (0)