
Commit 4b7c6b4

Count Sentiment training progress on a per-batch basis instead of by the number of items trained. Separate very long items into their own batches to avoid OOM errors, especially on the ZH dataset
1 parent 0a6d3cb commit 4b7c6b4
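The OOM motivation, roughly: when a batch is padded to the length of its longest item, memory scales with batch_size × max_len, so a single very long ZH document inflates the whole batch. A toy calculation of that effect (the lengths and embedding width below are illustrative, not measurements from the repo):

    # Illustrative padded-batch footprint; all numbers are hypothetical.
    batch_size, embed_dim = 50, 100
    typical_len, outlier_len = 30, 2000

    mixed_batch   = batch_size * outlier_len * embed_dim   # one outlier pads the whole batch
    typical_batch = batch_size * typical_len * embed_dim
    outlier_alone = 1 * outlier_len * embed_dim            # the fix: the outlier trains in a batch of one

    print(mixed_batch // typical_batch)   # 66 -> ~66x the activations of a typical batch
    print(mixed_batch // outlier_alone)   # 50 -> 50x what the outlier needs on its own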

File tree

2 files changed: +25 −14 lines


stanza/models/classifier.py

Lines changed: 10 additions & 12 deletions
@@ -173,7 +173,7 @@ def build_argparse():
     parser.add_argument('--test_file', type=str, default=DEFAULT_TEST, help='Input file(s) to use as the test set.')
     parser.add_argument('--output_predictions', default=False, action='store_true', help='Output predictions when running the test set')
     parser.add_argument('--max_epochs', type=int, default=100)
-    parser.add_argument('--tick', type=int, default=2000)
+    parser.add_argument('--tick', type=int, default=50)

     parser.add_argument('--model_type', type=lambda x: ModelType[x.upper()], default=ModelType.CNN,
                         help='Model type to use. Options: %s' % " ".join(x.name for x in ModelType))
@@ -184,7 +184,8 @@ def build_argparse():
     parser.add_argument('--dropout', default=0.5, type=float, help='Dropout value to use')

     parser.add_argument('--batch_size', default=50, type=int, help='Batch size when training')
-    parser.add_argument('--dev_eval_steps', default=100000, type=int, help='Run the dev set after this many train steps. Set to 0 to only do it once per epoch')
+    parser.add_argument('--batch_single_item', default=200, type=int, help='Items of this size go in their own batch')
+    parser.add_argument('--dev_eval_batches', default=2000, type=int, help='Run the dev set after this many train batches. Set to 0 to only do it once per epoch')
     parser.add_argument('--dev_eval_scoring', type=lambda x: DevScoring[x.upper()], default=DevScoring.WEIGHTED_F1,
                         help=('Scoring method to use for choosing the best model. Options: %s' %
                               " ".join(x.name for x in DevScoring)))
@@ -500,8 +501,6 @@ def train_model(trainer, model_file, checkpoint_file, args, train_set, dev_set,
     log_param_sizes(model)

     # https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
-    batch_starts = list(range(0, len(train_set), args.batch_size))
-
     if args.wandb:
         import wandb
         wandb_name = args.wandb_name if args.wandb_name else "%s_classifier" % args.shorthand
@@ -513,19 +512,18 @@ def train_model(trainer, model_file, checkpoint_file, args, train_set, dev_set,
     for trainer.epochs_trained in range(trainer.epochs_trained, args.max_epochs):
         running_loss = 0.0
         epoch_loss = 0.0
-        shuffled = data.shuffle_dataset(train_set_by_len)
+        shuffled_batches = data.shuffle_dataset(train_set_by_len, args.batch_size, args.batch_single_item)

         model.train()
         logger.info("Starting epoch %d", trainer.epochs_trained)
         if args.log_norms:
             model.log_norms()

-        random.shuffle(batch_starts)
-        for batch_num, start_batch in enumerate(batch_starts):
+        for batch_num, batch in enumerate(shuffled_batches):
+            # logger.debug("Batch size %d max len %d" % (len(batch), max(len(x.text) for x in batch)))
             trainer.global_step += 1
-            logger.debug("Starting batch: %d step %d", start_batch, trainer.global_step)
+            logger.debug("Starting batch: %d step %d", batch_num, trainer.global_step)

-            batch = shuffled[start_batch:start_batch+args.batch_size]
             batch_labels = torch.stack([label_tensors[x.sentiment] for x in batch])

             # zero the parameter gradients
@@ -539,12 +537,12 @@ def train_model(trainer, model_file, checkpoint_file, args, train_set, dev_set,

             # print statistics
             running_loss += batch_loss.item()
-            if ((batch_num + 1) * args.batch_size) % args.tick < args.batch_size: # print every 2000 items
+            if (batch_num + 1) % args.tick == 0: # print every so many batches
                 train_loss = running_loss / args.tick
-                logger.info('[%d, %5d] Average loss: %.3f', trainer.epochs_trained + 1, (batch_num + 1) * args.batch_size, train_loss)
+                logger.info('[%d, %5d] Average loss: %.3f', trainer.epochs_trained + 1, batch_num + 1, train_loss)
                 if args.wandb:
                     wandb.log({'train_loss': train_loss}, step=trainer.global_step)
-            if args.dev_eval_steps > 0 and ((batch_num + 1) * args.batch_size) % args.dev_eval_steps < args.batch_size:
+            if args.dev_eval_batches > 0 and (batch_num + 1) % args.dev_eval_batches == 0:
                 logger.info('---- Interim analysis ----')
                 dev_score, accuracy, macro_f1 = score_dev_set(model, dev_set, args.dev_eval_scoring)
                 if args.wandb:
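Under the defaults shown in this diff, the old per-item tick arithmetic and the new per-batch check fire at similar but not identical intervals; a quick sanity check of the two conditions:

    # Old default: --tick counted items, so with batch_size=50 and tick=2000 the
    # condition ((batch_num + 1) * 50) % 2000 < 50 fired every 40 batches.
    # New default: --tick counts batches, so (batch_num + 1) % 50 == 0 fires
    # every 50 batches regardless of batch size.
    batch_size, old_tick, new_tick = 50, 2000, 50
    old_hits = [b for b in range(200) if ((b + 1) * batch_size) % old_tick < batch_size]
    new_hits = [b for b in range(200) if (b + 1) % new_tick == 0]
    print(old_hits[:3])   # [39, 79, 119] -> every 40 batches
    print(new_hits[:3])   # [49, 99, 149] -> every 50 batches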

stanza/models/classifiers/data.py

Lines changed: 15 additions & 2 deletions
@@ -129,7 +129,7 @@ def sort_dataset_by_len(dataset, keep_index=False):
         sorted_dataset[len(item.text)].append(item)
     return sorted_dataset

-def shuffle_dataset(sorted_dataset):
+def shuffle_dataset(sorted_dataset, batch_size, batch_single_item):
     """
     Given a dataset sorted by len, sorts within each length to make
     chunks of roughly the same size. Returns all items as a single list.
@@ -139,7 +139,20 @@ def shuffle_dataset(sorted_dataset):
         items = list(sorted_dataset[l])
         random.shuffle(items)
         dataset.extend(items)
-    return dataset
+    batches = []
+    next_batch = []
+    for item in dataset:
+        if batch_single_item > 0 and len(item.text) >= batch_single_item:
+            batches.append([item])
+        else:
+            next_batch.append(item)
+            if len(next_batch) >= batch_size:
+                batches.append(next_batch)
+                next_batch = []
+    if len(next_batch) > 0:
+        batches.append(next_batch)
+    random.shuffle(batches)
+    return batches


 def check_labels(labels, dataset):
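A minimal usage sketch of the new return type, assuming a toy item with a .text field (SimpleNamespace stands in for whatever item class the dataset actually uses):

    from types import SimpleNamespace

    from stanza.models.classifiers import data

    # 70 short items plus one 300-character outlier; batch_single_item=200
    # sends the outlier to its own batch, the rest pack into groups of 50.
    items = [SimpleNamespace(text="w" * 10) for _ in range(70)]
    items.append(SimpleNamespace(text="w" * 300))

    by_len = data.sort_dataset_by_len(items)
    batches = data.shuffle_dataset(by_len, batch_size=50, batch_single_item=200)
    print(sorted(len(b) for b in batches))   # [1, 20, 50]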
