
Commit 89a45ec

Michael Wu authored and facebook-github-bot committed
Always create non-empty logits / targets for MLM (#979)
Summary:
Pull Request resolved: #979

D17241503 selects only masked tokens for the final logits / targets during MLM. This fails when there are no masked tokens (e.g. at the end of a file there can be a very short batch). In this case, select just the first token in the first batch. As a side note, the new masking strategy is about 20% faster than the old one - f138005051 vs f138005044.

Reviewed By: borguz

Differential Revision: D17370855

fbshipit-source-id: 62f0540fb94819c6a269ea067b1c0b9e08c82119
1 parent c2c31ac commit 89a45ec
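
As an illustration only (not code from this commit), here is a minimal Python sketch of the failure mode the summary describes: if no tokens were masked, selecting logits / targets at masked positions produces empty tensors, and the loss over an empty selection is not meaningful.

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 10)                 # 4 tokens, 10-word vocabulary
    targets = torch.randint(0, 10, (4,))
    mask = torch.zeros(4, dtype=torch.bool)     # a batch where nothing got masked

    selected_logits = logits[mask]              # shape (0, 10): empty selection
    selected_targets = targets[mask]            # shape (0,)
    # Depending on the framework version this is NaN or an error; either way
    # the MLM loss is broken, which is what the guard in this commit avoids.
    print(F.cross_entropy(selected_logits, selected_targets))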

File tree

2 files changed: +10 −3 lines


pytext/metric_reporters/language_model_metric_reporter.py

Lines changed: 6 additions & 3 deletions
@@ -230,7 +230,7 @@ def report_realtime_metric(self, stage):
             print(
                 f"Tokens/s: {last_batch_tps:.0f}, "
                 f"batch ppl: {math.exp(last_batch_loss):.2f}, "
-                f"agg ppl: {math.exp(aggregate_loss / float(total_masked_tokens)):.2f}, "
+                f"agg ppl: {math.exp(self._calculate_loss(aggregate_loss, total_masked_tokens)):.2f}, "
                 f"number of batches: {self.total_batches:.0f}, "
                 f"accumulated tokens/s: {tps:.0f}",
                 flush=True,
@@ -239,7 +239,7 @@ def report_realtime_metric(self, stage):
             print(
                 f"GPU-0 tokens/s: {self.last_batch_tps:.0f}, "
                 f"batch ppl: {math.exp(self.last_batch_loss):.2f}, "
-                f"agg ppl: {math.exp(self.aggregate_loss / float(self.total_masked_tokens)):.2f}, "
+                f"agg ppl: {math.exp(self.calculate_loss()):.2f}, "
                 f"number of batches: {self.total_batches}, "
                 f"accumulated tokens/s: {self.realtime_meters['tps'].avg:.0f}",
                 flush=True,
@@ -261,7 +261,10 @@ def report_realtime_metric(self, stage):
         )

     def calculate_loss(self) -> float:
-        return self.aggregate_loss / float(self.total_masked_tokens)
+        return self._calculate_loss(self.aggregate_loss, self.total_masked_tokens)
+
+    def _calculate_loss(self, aggregate_loss, total_masked_tokens) -> float:
+        return aggregate_loss / max(1, total_masked_tokens)

     def _reset(self):
         super()._reset()
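
A hedged, standalone sketch of the guarded division introduced above; the helper name mirrors the diff, but the example values are made up.

    import math

    def _calculate_loss(aggregate_loss, total_masked_tokens):
        # Guard against division by zero when a step saw no masked tokens.
        return aggregate_loss / max(1, total_masked_tokens)

    print(math.exp(_calculate_loss(0.0, 0)))   # agg ppl of 1.00 instead of a crash
    print(math.exp(_calculate_loss(6.2, 3)))   # ~7.90, the usual aggregate perplexity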

pytext/models/masked_lm.py

Lines changed: 4 additions & 0 deletions
@@ -166,6 +166,10 @@ def _get_mask(self, tokens):
         mask = self._select_tokens_to_mask(tokens, self.mask_prob)
         pad_mask = (tokens != self.vocab.get_pad_index()).long()
         mask *= pad_mask
+        if not mask.byte().any():
+            # Keep one masked token to avoid failure in the loss calculation.
+            mask[0, 0] = 1
+
         probs = torch.rand_like(tokens, dtype=torch.float)
         rand_mask = (probs < 0.1).long() * mask
         mask_mask = (probs >= 0.2).long() * mask
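
For illustration, a toy run of the new guard when the random selection masks nothing; the tensors and the pad index of 0 are assumptions for this sketch, not PyText's actual batching.

    import torch

    tokens = torch.tensor([[5, 6]])         # a very short batch at the end of a file
    mask = torch.zeros_like(tokens)         # _select_tokens_to_mask picked nothing
    pad_mask = (tokens != 0).long()         # pad index assumed to be 0 here
    mask *= pad_mask
    if not mask.byte().any():
        # Keep one masked token so the logits / targets selection is never empty.
        mask[0, 0] = 1
    print(mask)                             # tensor([[1, 0]])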
