Skip to content
This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit 9ba9838

Browse files
arbabu123 authored and facebook-github-bot committed
Fix broken gradients logging and add lr logging to tensorboard
Differential Revision: D18624642 fbshipit-source-id: c870ede41701edcdfc63405185b66d5a8ac418b6
1 parent 57b7dc3 commit 9ba9838

File tree

3 files changed

+17
-8
lines changed

3 files changed

+17
-8
lines changed

pytext/metric_reporters/channel.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ def report(
170170
context,
171171
meta,
172172
model,
173+
optimizer,
173174
*args,
174175
):
175176
"""
@@ -213,17 +214,18 @@ def report(
213214
self.add_scalars(prefix, metrics, epoch)
214215

215216
if stage == Stage.TRAIN:
217+
for idx, param_group in enumerate(optimizer.param_groups):
218+
self.summary_writer.add_scalar(
219+
f"optimizer.lr.param_group.{idx}", param_group["lr"], epoch
220+
)
216221
for key, val in model.named_parameters():
217222
if val is not None and len(val) > 0 and not (val == 0).all():
218223
limit = 9.9e19
224+
grad = val.grad
219225
val = torch.clamp(val.float(), -limit, limit)
220226
self.summary_writer.add_histogram(key, val, epoch)
221-
if (
222-
val.grad is not None
223-
and len(val.grad) > 0
224-
and not (val.grad == 0).all()
225-
):
226-
grad = torch.clamp(val.grad.float(), -limit, limit)
227+
if grad is not None and len(grad) > 0 and not (grad == 0).all():
228+
grad = torch.clamp(grad.float(), -limit, limit)
227229
self.summary_writer.add_histogram(
228230
key + "_gradients", grad, epoch
229231
)

pytext/metric_reporters/metric_reporter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,9 @@ def get_meta(self):
206206
"""
207207
return {}
208208

209-
def report_metric(self, model, stage, epoch, reset=True, print_to_channels=True):
209+
def report_metric(
210+
self, model, optimizer, stage, epoch, reset=True, print_to_channels=True
211+
):
210212
"""
211213
Calculate metrics and average loss, report all statistic data to channels
212214
@@ -241,6 +243,7 @@ def report_metric(self, model, stage, epoch, reset=True, print_to_channels=True)
241243
self.all_context,
242244
self.get_meta(),
243245
model,
246+
optimizer,
244247
)
245248

246249
if reset:

pytext/trainers/trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,11 @@ def run_epoch(
509509
if report_metric:
510510
with timing.time("report metrics"):
511511
metrics = metric_reporter.report_metric(
512-
model, state.stage, state.epoch, print_to_channels=(state.rank == 0)
512+
model,
513+
self.optimizer,
514+
state.stage,
515+
state.epoch,
516+
print_to_channels=(state.rank == 0),
513517
)
514518
else:
515519
metric_reporter._reset()

0 commit comments

Comments (0)