
Commit 8986b5b

Jemoka authored and AngledLuffa committed
Implements 2 stage optimization for dependency parsing
1) save the optimizer in the checkpoint; 2) track the two stages using args. Save the checkpoints after switching the optimizer, if applicable, so that reloading uses the new optimizer once it has been created.
1 parent 809de1a commit 8986b5b
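
In outline: training runs with the primary optimizer (--optim/--lr) until the dev score has not improved for max_steps_before_stop steps, then the best model is reloaded and training continues with the secondary optimizer (--second_optim/--second_lr); checkpoints written after the switch carry the optimizer state, so an interrupted run resumes in the correct stage. Below is a generic, self-contained sketch of that schedule in plain PyTorch, with a toy model, illustrative thresholds, and an illustrative checkpoint path; it is not the Stanza code (which follows in the diffs), and for brevity it switches optimizers in place rather than reloading the best saved weights.

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
primary = torch.optim.Adam(model.parameters(), lr=3e-3)
secondary = torch.optim.SGD(model.parameters(), lr=3e-4)     # illustrative secondary choice
optimizer, second_stage = primary, False

best_loss, last_best_step = float("inf"), 0
for step in range(1, 2001):
    x, y = torch.randn(32, 10), torch.randn(32, 1)           # toy data
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if loss.item() < best_loss:                               # stand-in for "dev score improved"
        best_loss, last_best_step = loss.item(), step

    if step - last_best_step >= 200:                          # plateau detected
        if not second_stage:
            # switch to the secondary optimizer
            # (the commit additionally reloads the best saved model at this point)
            optimizer, second_stage, last_best_step = secondary, True, step
        else:
            break                                             # plateau in the final stage: stop

    if step % 100 == 0:
        # checkpoint after any switch, so reloading restores the current optimizer
        torch.save({"model": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "second_stage": second_stage},
                   "two_stage_checkpoint.pt")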

File tree

2 files changed: +81 −14 lines changed
  stanza/models/depparse/trainer.py
  stanza/models/parser.py

stanza/models/depparse/trainer.py

Lines changed: 43 additions & 6 deletions
@@ -2,6 +2,7 @@
 A trainer class to handle training and testing of models.
 """
 
+import copy
 import sys
 import logging
 import torch
@@ -28,17 +29,42 @@ def unpack_batch(batch, device):
 
 class Trainer(BaseTrainer):
     """ A trainer for training models. """
-    def __init__(self, args=None, vocab=None, pretrain=None, model_file=None, device=None, foundation_cache=None):
+    def __init__(self, args=None, vocab=None, pretrain=None, model_file=None,
+                 device=None, foundation_cache=None, ignore_model_config=False):
+        orig_args = copy.deepcopy(args)
+        # whether the training is in primary or secondary stage
+        # during FT (loading weights), etc., the training is considered to be in "secondary stage"
+        # during this time, we (optionally) use a different set of optimizers than that during "primary stage".
+        #
+        # Regardless, we use TWO SETS of optimizers; once primary converges, we switch to secondary
        if model_file is not None:
            # load everything from file
-            self.load(model_file, pretrain, args, foundation_cache)
+            self.load(model_file, pretrain, args, foundation_cache, device)
        else:
            # build model from scratch
            self.args = args
            self.vocab = vocab
            self.model = Parser(args, vocab, emb_matrix=pretrain.emb if pretrain is not None else None)
-            self.model = self.model.to(device)
-            self.optimizer = utils.get_optimizer(self.args['optim'], self.model, self.args['lr'], betas=(0.9, self.args['beta2']), eps=1e-6, bert_learning_rate=self.args.get('bert_learning_rate', 0.0))
+            self.model = self.model.to(device)
+            self.__init_optim()
+
+        if ignore_model_config:
+            self.args = orig_args
+
+        if self.args.get('wandb'):
+            import wandb
+            # track gradients!
+            wandb.watch(self.model, log_freq=4, log="all", log_graph=True)
+
+    def __init_optim(self):
+        if (self.args.get("second_stage", False) and self.args.get('second_optim')):
+            self.optimizer = utils.get_optimizer(self.args['second_optim'], self.model,
+                                                 self.args['second_lr'], betas=(0.9, self.args['beta2']), eps=1e-6,
+                                                 bert_learning_rate=self.args.get('second_bert_learning_rate', 0.0))
+        else:
+            self.optimizer = utils.get_optimizer(self.args['optim'], self.model,
+                                                 self.args['lr'], betas=(0.9, self.args['beta2']),
+                                                 eps=1e-6, bert_learning_rate=self.args.get('bert_learning_rate', 0.0))
 
     def update(self, batch, eval=False):
         device = next(self.model.parameters()).device
@@ -76,7 +102,7 @@ def predict(self, batch, unsort=True):
         pred_tokens = utils.unsort(pred_tokens, orig_idx)
         return pred_tokens
 
-    def save(self, filename, skip_modules=True):
+    def save(self, filename, skip_modules=True, save_optimizer=False):
         model_state = self.model.state_dict()
         # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file
         if skip_modules:
@@ -88,13 +114,17 @@ def save(self, filename, skip_modules=True):
                  'vocab': self.vocab.state_dict(),
                  'config': self.args
                  }
+
+        if save_optimizer and self.optimizer is not None:
+            params['optimizer_state_dict'] = self.optimizer.state_dict()
+
         try:
             torch.save(params, filename, _use_new_zipfile_serialization=False)
             logger.info("Model saved to {}".format(filename))
         except BaseException:
             logger.warning("Saving failed... continuing anyway.")
 
-    def load(self, filename, pretrain, args=None, foundation_cache=None):
+    def load(self, filename, pretrain, args=None, foundation_cache=None, device=None):
         """
         Load a model from file, with preloaded pretrain embeddings. Here we allow the pretrain to be None or a dummy input,
         and the actual use of pretrain embeddings will depend on the boolean config "pretrain" in the loaded args.
@@ -119,4 +149,11 @@ def load(self, filename, pretrain, args=None, foundation_cache=None):
             foundation_cache = NoTransformerFoundationCache(foundation_cache)
         self.model = Parser(self.args, self.vocab, emb_matrix=emb_matrix, foundation_cache=foundation_cache)
         self.model.load_state_dict(checkpoint['model'], strict=False)
+        if device is not None:
+            self.model = self.model.to(device)
+
+        self.__init_optim()
+        optim_state_dict = checkpoint.get("optimizer_state_dict")
+        if optim_state_dict:
+            self.optimizer.load_state_dict(optim_state_dict)
 
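To make the Trainer changes above concrete, here is a hedged usage outline of the updated interface: save() can now persist the optimizer, and load() takes the target device, rebuilds the optimizer through __init_optim(), and restores its saved state when present. The args/vocab/pretrain objects are the usual depparse config, vocabulary, and pretrained-embedding objects built by stanza/models/parser.py, and the checkpoint path is illustrative.

from stanza.models.depparse.trainer import Trainer

def checkpoint_round_trip(args, vocab, pretrain, path="saved_models/depparse/checkpoint.pt"):
    # fresh model; __init_optim() picks args['optim'] unless both args['second_stage']
    # and args['second_optim'] are set, in which case the secondary optimizer is built
    trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'])

    # persist 'optimizer_state_dict' alongside 'model', 'vocab', and 'config'
    trainer.save(path, save_optimizer=True)

    # reload: the model is moved to the device, the optimizer is rebuilt via
    # __init_optim(), and the saved optimizer state is restored if it exists;
    # ignore_model_config keeps the command-line args rather than the stored ones
    return Trainer(args=args, vocab=vocab, pretrain=pretrain, model_file=path,
                   device=args['device'], ignore_model_config=True)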

stanza/models/parser.py

Lines changed: 38 additions & 8 deletions
@@ -11,6 +11,7 @@
 
 import sys
 import os
+import copy
 import shutil
 import time
 import argparse
@@ -43,7 +44,6 @@ def build_argparse():
     parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
     parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.')
     parser.add_argument('--gold_file', type=str, default=None, help='Output CoNLL-U file.')
-
     parser.add_argument('--mode', default='train', choices=['train', 'predict'])
     parser.add_argument('--lang', type=str, help='Language')
     parser.add_argument('--shorthand', type=str, help="Treebank shorthand")
@@ -58,6 +58,8 @@ def build_argparse():
     parser.add_argument('--transformed_dim', type=int, default=125)
     parser.add_argument('--num_layers', type=int, default=3)
     parser.add_argument('--char_num_layers', type=int, default=1)
+    parser.add_argument('--checkpoint_save_name', type=str, default=None, help="File name to save the most recent checkpoint")
+    parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help="Don't save checkpoints")
     parser.add_argument('--pretrain_max_vocab', type=int, default=250000)
     parser.add_argument('--word_dropout', type=float, default=0.33)
     parser.add_argument('--dropout', type=float, default=0.5)
@@ -77,24 +79,28 @@ def build_argparse():
     parser.add_argument('--bert_finetune', default=False, action='store_true', help='Finetune the bert (or other transformer)')
     parser.add_argument('--no_bert_finetune', dest='bert_finetune', action='store_false', help="Don't finetune the bert (or other transformer)")
     parser.add_argument('--bert_learning_rate', default=1.0, type=float, help='Scale the learning rate for transformer finetuning by this much')
+    parser.add_argument('--second_bert_learning_rate', default=1e-3, type=float, help='Secondary stage transformer finetuning learning rate scale')
 
     parser.add_argument('--no_pretrain', dest='pretrain', action='store_false', help="Turn off pretrained embeddings.")
     parser.add_argument('--no_linearization', dest='linearization', action='store_false', help="Turn off linearization term.")
     parser.add_argument('--no_distance', dest='distance', action='store_false', help="Turn off distance term.")
 
     parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')
     parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam or adamax.')
+    parser.add_argument('--second_optim', type=str, default=None, help='sgd, adagrad, adam or adamax.')
     parser.add_argument('--lr', type=float, default=3e-3, help='Learning rate')
+    parser.add_argument('--second_lr', type=float, default=3e-4, help='Secondary stage learning rate')
     parser.add_argument('--beta2', type=float, default=0.95)
 
     parser.add_argument('--max_steps', type=int, default=50000)
     parser.add_argument('--eval_interval', type=int, default=100)
-    parser.add_argument('--max_steps_before_stop', type=int, default=3000)
+    parser.add_argument('--max_steps_before_stop', type=int, default=1000)
     parser.add_argument('--batch_size', type=int, default=5000)
     parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Gradient clipping.')
     parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.')
     parser.add_argument('--save_dir', type=str, default='saved_models/depparse', help='Root dir for saving models.')
     parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_parser.pt", help="File name to save the model")
+    parser.add_argument('--continue_from', type=str, default=None, help="File name to preload the model to continue training from")
 
     parser.add_argument('--seed', type=int, default=1234)
     utils.add_device_args(parser)
@@ -191,7 +197,19 @@ def train(args):
         wandb.run.define_metric('dev_score', summary='max')
 
     logger.info("Training parser...")
-    trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'])
+    # calculate checkpoint file name and the sav
+    model_to_load = None # used for general loading and reloading
+    checkpoint_file = None # used explicitly as the *PATH TO THE CHECKPOINT* could be None if we don't want to save chkpt
+    if args.get("checkpoint"):
+        model_to_load = utils.checkpoint_name(args.get("save_dir"), model_file, args.get("checkpoint_save_name"))
+        checkpoint_file = copy.deepcopy(model_to_load)
+    if args["continue_from"]:
+        model_to_load = args["continue_from"]
+
+    if model_to_load is not None and os.path.exists(model_to_load):
+        trainer = Trainer(args=args, pretrain=pretrain, vocab=vocab, model_file=model_to_load, device=args['device'], ignore_model_config=True)
+    else:
+        trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'])
 
     global_step = 0
     max_steps = args['max_steps']
@@ -201,7 +219,7 @@ def train(args):
     global_start_time = time.time()
     format_str = 'Finished STEP {}/{}, loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'
 
-    using_amsgrad = False
+    is_second_stage = False
     last_best_step = 0
     # start training
     train_loss = 0
@@ -247,15 +265,27 @@
                 dev_score_history += [dev_score]
 
             if global_step - last_best_step >= args['max_steps_before_stop']:
-                if not using_amsgrad:
-                    logger.info("Switching to AMSGrad")
+                if not is_second_stage and args.get('second_optim', None) is not None:
+                    logger.info("Switching to second optimizer: {}".format(args.get('second_optim', None)))
+                    args["second_stage"] = True
+                    # if the loader gets a model file, it uses secondary optimizer
+                    trainer = Trainer(args=args, vocab=trainer.vocab, pretrain=pretrain,
+                                      model_file=model_file, device=args['device'])
+                    logger.info('Reloading best model to continue from current local optimum')
+                    is_second_stage = True
                     last_best_step = global_step
-                    using_amsgrad = True
-                    trainer.optimizer = optim.Adam(trainer.model.parameters(), amsgrad=True, lr=args['lr'], betas=(.9, args['beta2']), eps=1e-6)
                 else:
                     do_break = True
                     break
 
+            if global_step % args['eval_interval'] == 0:
+                # if we need to save checkpoint, do so
+                # (save after switching the optimizer, if applicable, so that
+                # the new optimizer is the optimizer used if a restart happens)
+                if checkpoint_file is not None:
+                    trainer.save(checkpoint_file, save_optimizer=True)
+                    logger.info("new model checkpoint saved.")
+
             if global_step >= args['max_steps']:
                 do_break = True
                 break
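Taken together, the new parser.py flags expose the schedule on the command line: --second_optim, --second_lr, and --second_bert_learning_rate configure the secondary stage; --checkpoint_save_name and --no_checkpoint control the rolling checkpoint that now carries optimizer state; and --continue_from restarts training from an arbitrary saved model instead of the checkpoint. A hypothetical invocation under those assumptions (the shorthand, data paths, and checkpoint name are placeholders; embedding-related and other standard flags are omitted):

python -m stanza.models.parser --mode train --shorthand en_ewt \
    --train_file data/depparse/en_ewt.train.in.conllu \
    --eval_file data/depparse/en_ewt.dev.in.conllu \
    --gold_file data/depparse/en_ewt.dev.gold.conllu \
    --optim adam --lr 3e-3 --second_optim sgd --second_lr 3e-4 \
    --checkpoint_save_name saved_models/depparse/en_ewt_parser_checkpoint.pt

With these settings, a plateau of max_steps_before_stop (now 1000) steps triggers the switch to the secondary optimizer, and a second plateau ends training.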
