Commit d5f7a40

sliorde and Svetlana Karslioglu authored
rename confusing variable name in LM tutorial (#1953)
The tutorial "Language Modeling with nn.Transformer and TorchText" contains code snippets with variables named `batch_size`. The issue is that in some places `batch_size` means the number of sequences in a batch, while in other places it means the number of tokens in each sequence of the batch. This commit resolves the inconsistency: `batch_size` is replaced with `seq_len` in the two places where it has the latter meaning. Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent ae22720 commit d5f7a40
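
For context on the rename, here is a paraphrased sketch of the tutorial's get_batch helper (not part of this diff; the value of bptt is the tutorial's default and is stated here as an assumption). The batchified data has shape [full_seq_len, batch_size], and get_batch slices out at most bptt rows at a time, so data.size(0) counts tokens per sequence rather than sequences per batch.

from typing import Tuple

from torch import Tensor

bptt = 35  # backpropagation-through-time window (assumed tutorial default)

def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
    # source has shape [full_seq_len, batch_size]; slicing rows keeps the
    # batch dimension intact, so data.size(0) is seq_len, not batch_size.
    seq_len = min(bptt, len(source) - 1 - i)  # shorter only for the final chunk
    data = source[i:i + seq_len]              # shape [seq_len, batch_size]
    target = source[i + 1:i + 1 + seq_len].reshape(-1)
    return data, target

For example, for source of shape [1000, 20], get_batch(source, 980) returns data with data.size(0) == 19, which is why the attention mask must be sliced on the last batch.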

File tree

1 file changed (+7, -7 lines)


beginner_source/transformer_tutorial.py

Lines changed: 7 additions & 7 deletions
@@ -297,9 +297,9 @@ def train(model: nn.Module) -> None:
     num_batches = len(train_data) // bptt
     for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
         data, targets = get_batch(train_data, i)
-        batch_size = data.size(0)
-        if batch_size != bptt:  # only on last batch
-            src_mask = src_mask[:batch_size, :batch_size]
+        seq_len = data.size(0)
+        if seq_len != bptt:  # only on last batch
+            src_mask = src_mask[:seq_len, :seq_len]
         output = model(data, src_mask)
         loss = criterion(output.view(-1, ntokens), targets)

@@ -327,12 +327,12 @@ def evaluate(model: nn.Module, eval_data: Tensor) -> float:
     with torch.no_grad():
         for i in range(0, eval_data.size(0) - 1, bptt):
             data, targets = get_batch(eval_data, i)
-            batch_size = data.size(0)
-            if batch_size != bptt:
-                src_mask = src_mask[:batch_size, :batch_size]
+            seq_len = data.size(0)
+            if seq_len != bptt:
+                src_mask = src_mask[:seq_len, :seq_len]
             output = model(data, src_mask)
             output_flat = output.view(-1, ntokens)
-            total_loss += batch_size * criterion(output_flat, targets).item()
+            total_loss += seq_len * criterion(output_flat, targets).item()
     return total_loss / (len(eval_data) - 1)

 ######################################################################
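
As a follow-up note on the last changed line (an illustrative sketch, not part of the commit): criterion in the tutorial is a mean-reducing CrossEntropyLoss, so weighting each batch's mean loss by seq_len and dividing by len(eval_data) - 1 (the total number of target rows) yields a per-token average even when the final batch is shorter than bptt. The numbers below are made up purely to show the arithmetic.

# Dummy per-batch values, chosen only to illustrate the length weighting.
mean_losses = [2.0, 3.0]      # mean loss returned by criterion for each batch
seq_lens = [35, 20]           # the last batch is shorter than bptt
total_rows = sum(seq_lens)    # plays the role of len(eval_data) - 1

# Length-weighted average of the per-batch mean losses (~2.36 here).
weighted_avg = sum(s * l for s, l in zip(seq_lens, mean_losses)) / total_rows
print(weighted_avg)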
