
Commit c034a23

arbabu123 authored and facebook-github-bot committed
Fix broken gradients logging and add lr logging to tensorboard (facebookresearch#1158)
Summary: Pull Request resolved: facebookresearch#1158

This should help to monitor the learning rate (lr) when using warmup, annealing, etc.

Reviewed By: geof90

Differential Revision: D18624642

fbshipit-source-id: 0cf55150f40c8a3ddf459d9d968f15f58356c488
1 parent 98e6761 commit c034a23
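
To make the learning-rate logging concrete, here is a minimal standalone sketch (plain PyTorch plus TensorBoard, not pytext code) of reading the lr from each optimizer param group and writing it as a per-epoch scalar, which is what the channel change below does; the model, scheduler, and tag names are only illustrative.

    import torch
    from torch.utils.tensorboard import SummaryWriter

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # Any scheduler works; StepLR here just makes the logged lr visibly change.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
    writer = SummaryWriter()

    for epoch in range(6):
        # ... training steps for this epoch would run here ...
        for idx, param_group in enumerate(optimizer.param_groups):
            writer.add_scalar(f"optimizer.lr.param_group.{idx}", param_group["lr"], epoch)
        scheduler.step()
    writer.close()

With warmup or annealing schedules, this scalar curve is the quickest way to confirm the schedule is actually being applied.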

File tree: 3 files changed, 20 additions & 8 deletions

pytext/metric_reporters/channel.py

Lines changed: 9 additions & 6 deletions
@@ -171,6 +171,7 @@ def report(
         meta,
         model,
         *args,
+        optimizer=None,
     ):
         """
         Defines how to format and report data to TensorBoard using the summary
@@ -213,17 +214,19 @@ def report(
         self.add_scalars(prefix, metrics, epoch)
 
         if stage == Stage.TRAIN:
+            if optimizer is not None:
+                for idx, param_group in enumerate(optimizer.param_groups):
+                    self.summary_writer.add_scalar(
+                        f"optimizer.lr.param_group.{idx}", param_group["lr"], epoch
+                    )
             for key, val in model.named_parameters():
                 if val is not None and len(val) > 0 and not (val == 0).all():
                     limit = 9.9e19
+                    grad = val.grad
                     val = torch.clamp(val.float(), -limit, limit)
                     self.summary_writer.add_histogram(key, val, epoch)
-                    if (
-                        val.grad is not None
-                        and len(val.grad) > 0
-                        and not (val.grad == 0).all()
-                    ):
-                        grad = torch.clamp(val.grad.float(), -limit, limit)
+                    if grad is not None and len(grad) > 0 and not (grad == 0).all():
+                        grad = torch.clamp(grad.float(), -limit, limit)
                         self.summary_writer.add_histogram(
                             key + "_gradients", grad, epoch
                         )
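
The gradient-logging fix above works because torch.clamp(val.float(), ...) returns a new tensor whose .grad is never populated, so the old check on val.grad after the reassignment could not pass and gradient histograms were silently skipped. A small illustrative snippet (not pytext code) of why capturing grad = val.grad before reassigning val matters:

    import torch

    param = torch.nn.Parameter(torch.ones(3))
    param.sum().backward()       # populates param.grad
    print(param.grad)            # tensor([1., 1., 1.])

    grad = param.grad            # capture the gradient before overwriting the name
    param = torch.clamp(param.float(), -9.9e19, 9.9e19)
    print(param.grad)            # None (recent PyTorch also warns: non-leaf .grad is not populated)
    print(grad)                  # tensor([1., 1., 1.]) -- still available for logging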

pytext/metric_reporters/metric_reporter.py

Lines changed: 4 additions & 1 deletion
@@ -206,7 +206,9 @@ def get_meta(self):
         """
         return {}
 
-    def report_metric(self, model, stage, epoch, reset=True, print_to_channels=True):
+    def report_metric(
+        self, model, stage, epoch, reset=True, print_to_channels=True, optimizer=None
+    ):
         """
         Calculate metrics and average loss, report all statistic data to channels
 
@@ -241,6 +243,7 @@ def report_metric(self, model, stage, epoch, reset=True, print_to_channels=True)
                     self.all_context,
                     self.get_meta(),
                     model,
+                    optimizer,
                 )
 
         if reset:
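
The new argument defaults to None, so existing call sites keep working and only callers that actually have an optimizer need to pass one. A toy sketch of that pattern (simplified signature, not the real pytext API):

    import torch

    def report_metric(model, stage, epoch, reset=True, print_to_channels=True, optimizer=None):
        # A trailing default keeps older calls, which never passed an optimizer, valid.
        if optimizer is None:
            return None
        return [group["lr"] for group in optimizer.param_groups]

    model = torch.nn.Linear(2, 2)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    print(report_metric(model, "train", 0))                 # None -- old-style call
    print(report_metric(model, "train", 0, optimizer=opt))  # [0.001]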

pytext/trainers/trainer.py

Lines changed: 7 additions & 1 deletion
@@ -509,7 +509,13 @@ def run_epoch(
         if report_metric:
             with timing.time("report metrics"):
                 metrics = metric_reporter.report_metric(
-                    model, state.stage, state.epoch, print_to_channels=(state.rank == 0)
+                    model,
+                    state.stage,
+                    state.epoch,
+                    print_to_channels=(state.rank == 0),
+                    optimizer=getattr(
+                        state, "optimizer", None
+                    ),  # optimizer is not present during test
                 )
         else:
             metric_reporter._reset()
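
getattr(state, "optimizer", None) lets the same reporting call run in every stage: during test the state object may not carry an optimizer at all, and the resulting None simply skips lr logging downstream. A tiny illustration (hypothetical state classes, not pytext code):

    class TrainState:
        def __init__(self, optimizer):
            self.optimizer = optimizer

    class TestState:
        pass  # no optimizer attribute at test time

    for state in (TrainState(optimizer="sgd"), TestState()):
        optimizer = getattr(state, "optimizer", None)
        print(optimizer)  # "sgd" for training, None for test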
