
Commit 5bc22dd

Put the global_step and dev score history into the model files so that when a checkpoint is loaded, training continues from where it left off rather than restarting from 0
Report some details of the loaded model after loading it
1 parent e64f9b1 commit 5bc22dd
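
In outline, the saved model/checkpoint dict now carries the training history next to the weights, and the loader falls back to a zeroed history for older files. A minimal sketch of that round trip, assuming torch.save/torch.load as the serialization layer (not shown in this diff); the helper names are illustrative, not the actual Stanza API:

import torch

def save_checkpoint(filename, model_state, global_step, last_best_step, dev_score_history):
    # the same three history keys that trainer.py now adds to its params dict
    params = {
        'model': model_state,
        'global_step': global_step,
        'last_best_step': last_best_step,
        'dev_score_history': dev_score_history,
    }
    torch.save(params, filename)

def load_checkpoint(filename):
    checkpoint = torch.load(filename, map_location='cpu')
    # older checkpoints without these keys resume from step 0, as before this commit
    return (checkpoint.get('global_step', 0),
            checkpoint.get('last_best_step', 0),
            checkpoint.get('dev_score_history', []))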

2 files changed: 47 additions & 32 deletions

stanza/models/depparse/trainer.py

Lines changed: 18 additions & 2 deletions
@@ -30,16 +30,26 @@ def unpack_batch(batch, device):
 class Trainer(BaseTrainer):
     """ A trainer for training models. """
     def __init__(self, args=None, vocab=None, pretrain=None, model_file=None,
-                 device=None, foundation_cache=None, ignore_model_config=False):
+                 device=None, foundation_cache=None, ignore_model_config=False, reset_history=False):
+        self.global_step = 0
+        self.last_best_step = 0
+        self.dev_score_history = []
+
         orig_args = copy.deepcopy(args)
         # whether the training is in primary or secondary stage
         # during FT (loading weights), etc., the training is considered to be in "secondary stage"
         # during this time, we (optionally) use a different set of optimizers than that during "primary stage".
         #
         # Regardless, we use TWO SETS of optimizers; once primary converges, we switch to secondary
+
         if model_file is not None:
             # load everything from file
             self.load(model_file, pretrain, args, foundation_cache, device)
+
+            if reset_history:
+                self.global_step = 0
+                self.last_best_step = 0
+                self.dev_score_history = []
         else:
             # build model from scratch
             self.args = args
@@ -112,7 +122,10 @@ def save(self, filename, skip_modules=True, save_optimizer=False):
         params = {
                 'model': model_state,
                 'vocab': self.vocab.state_dict(),
-                'config': self.args
+                'config': self.args,
+                'global_step': self.global_step,
+                'last_best_step': self.last_best_step,
+                'dev_score_history': self.dev_score_history,
                 }

         if save_optimizer and self.optimizer is not None:
@@ -157,3 +170,6 @@ def load(self, filename, pretrain, args=None, foundation_cache=None, device=None
         if optim_state_dict:
             self.optimizer.load_state_dict(optim_state_dict)

+        self.global_step = checkpoint.get("global_step", 0)
+        self.last_best_step = checkpoint.get("last_best_step", 0)
+        self.dev_score_history = checkpoint.get("dev_score_history", list())

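As a usage sketch (not part of this commit), assuming args, pretrain, and vocab have been built as in stanza/models/parser.py and with placeholder file paths, the new reset_history flag distinguishes the two loading modes:

from stanza.models.depparse.trainer import Trainer

# Resuming from a checkpoint: Trainer.load() restores global_step,
# last_best_step, and dev_score_history from the saved file.
trainer = Trainer(args=args, pretrain=pretrain, vocab=vocab,
                  model_file="saved_models/depparse/parser_checkpoint.pt",
                  device=args['device'], ignore_model_config=True)

# Fine-tuning from an existing model (--continue_from): the weights are
# loaded, but reset_history=True zeroes the counters so training starts
# counting from step 0 again.
trainer = Trainer(args=args, pretrain=pretrain, vocab=vocab,
                  model_file="saved_models/depparse/previous_parser.pt",
                  device=args['device'], ignore_model_config=True,
                  reset_history=True)
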
stanza/models/parser.py

Lines changed: 29 additions & 30 deletions
@@ -198,45 +198,42 @@ def train(args):
         wandb.run.define_metric('dev_score', summary='max')

     logger.info("Training parser...")
-    # calculate checkpoint file name and the sav
-    model_to_load = None # used for general loading and reloading
-    checkpoint_file = None # used explicitly as the *PATH TO THE CHECKPOINT* could be None if we don't want to save chkpt
     if args.get("checkpoint"):
-        model_to_load = utils.checkpoint_name(args.get("save_dir"), model_file, args.get("checkpoint_save_name"))
-        checkpoint_file = model_to_load
+        # calculate checkpoint file name from the save filename
+        checkpoint_file = utils.checkpoint_name(args.get("save_dir"), model_file, args.get("checkpoint_save_name"))
         args["checkpoint_save_name"] = checkpoint_file
-    if args["continue_from"]:
-        model_to_load = args["continue_from"]

-    if model_to_load is not None and os.path.exists(model_to_load):
-        trainer = Trainer(args=args, pretrain=pretrain, vocab=vocab, model_file=model_to_load, device=args['device'], ignore_model_config=True)
+    if args.get("checkpoint") and os.path.exists(args["checkpoint_save_name"]):
+        trainer = Trainer(args=args, pretrain=pretrain, vocab=vocab, model_file=args["checkpoint_save_name"], device=args['device'], ignore_model_config=True)
+        if len(trainer.dev_score_history) > 0:
+            logger.info("Continuing from checkpoint %s Model was previously trained for %d steps, with a best dev score of %.4f", args["checkpoint_save_name"], trainer.global_step, max(trainer.dev_score_history))
+    elif args["continue_from"]:
+        if not os.path.exists(args["continue_from"]):
+            raise FileNotFoundError("--continue_from specified, but the file %s does not exist" % args["continue_from"])
+        trainer = Trainer(args=args, pretrain=pretrain, vocab=vocab, model_file=args["continue_from"], device=args['device'], ignore_model_config=True, reset_history=True)
     else:
         trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'])

-    global_step = 0
     max_steps = args['max_steps']
-    dev_score_history = []
-    best_dev_preds = []
     current_lr = args['lr']
     global_start_time = time.time()
     format_str = 'Finished STEP {}/{}, loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'

     is_second_stage = False
-    last_best_step = 0
     # start training
     train_loss = 0
     while True:
         do_break = False
         for i, batch in enumerate(train_batch):
             start_time = time.time()
-            global_step += 1
+            trainer.global_step += 1
             loss = trainer.update(batch, eval=False) # update step
             train_loss += loss
-            if global_step % args['log_step'] == 0:
+            if trainer.global_step % args['log_step'] == 0:
                 duration = time.time() - start_time
-                logger.info(format_str.format(global_step, max_steps, loss, duration, current_lr))
+                logger.info(format_str.format(trainer.global_step, max_steps, loss, duration, current_lr))

-            if global_step % args['eval_interval'] == 0:
+            if trainer.global_step % args['eval_interval'] == 0:
                 # eval on dev
                 logger.info("Evaluating on dev set...")
                 dev_preds = []
@@ -250,60 +247,62 @@ def train(args):
                 _, _, dev_score = scorer.score(system_pred_file, gold_file)

                 train_loss = train_loss / args['eval_interval'] # avg loss per batch
-                logger.info("step {}: train_loss = {:.6f}, dev_score = {:.4f}".format(global_step, train_loss, dev_score))
+                logger.info("step {}: train_loss = {:.6f}, dev_score = {:.4f}".format(trainer.global_step, train_loss, dev_score))

                 if args['wandb']:
                     wandb.log({'train_loss': train_loss, 'dev_score': dev_score})

                 train_loss = 0

                 # save best model
-                if len(dev_score_history) == 0 or dev_score > max(dev_score_history):
-                    last_best_step = global_step
+                if len(trainer.dev_score_history) == 0 or dev_score > max(trainer.dev_score_history):
+                    trainer.last_best_step = trainer.global_step
                     trainer.save(model_file)
                     logger.info("new best model saved.")
-                    best_dev_preds = dev_preds

-                dev_score_history += [dev_score]
+                trainer.dev_score_history += [dev_score]

             if not is_second_stage and args.get('second_optim', None) is not None:
-                if global_step - last_best_step >= args['max_steps_before_stop'] or (args['second_optim_start_step'] is not None and global_step >= args['second_optim_start_step']):
+                if trainer.global_step - trainer.last_best_step >= args['max_steps_before_stop'] or (args['second_optim_start_step'] is not None and trainer.global_step >= args['second_optim_start_step']):
                     logger.info("Switching to second optimizer: {}".format(args.get('second_optim', None)))
                     args["second_stage"] = True
                     # if the loader gets a model file, it uses secondary optimizer
                     trainer = Trainer(args=args, vocab=trainer.vocab, pretrain=pretrain,
                                       model_file=model_file, device=args['device'])
                     logger.info('Reloading best model to continue from current local optimum')
                     is_second_stage = True
-                    last_best_step = global_step
+                    trainer.last_best_step = trainer.global_step
             else:
-                if global_step - last_best_step >= args['max_steps_before_stop']:
+                if trainer.global_step - trainer.last_best_step >= args['max_steps_before_stop']:
                     do_break = True
                     break

-            if global_step % args['eval_interval'] == 0:
+            if trainer.global_step % args['eval_interval'] == 0:
                 # if we need to save checkpoint, do so
                 # (save after switching the optimizer, if applicable, so that
                 #  the new optimizer is the optimizer used if a restart happens)
                 if checkpoint_file is not None:
                     trainer.save(checkpoint_file, save_optimizer=True)
                     logger.info("new model checkpoint saved.")

-            if global_step >= args['max_steps']:
+            if trainer.global_step >= args['max_steps']:
                 do_break = True
                 break

         if do_break: break

         train_batch.reshuffle()

-    logger.info("Training ended with {} steps.".format(global_step))
+    logger.info("Training ended with {} steps.".format(trainer.global_step))

     if args['wandb']:
         wandb.finish()

-    if len(dev_score_history) > 0:
-        best_f, best_eval = max(dev_score_history)*100, np.argmax(dev_score_history)+1
+    if len(trainer.dev_score_history) > 0:
+        # TODO: technically the iteration position will be wrong if
+        #  the eval_interval changed when running from a checkpoint
+        #  could fix this by saving step & score instead of just score
+        best_f, best_eval = max(trainer.dev_score_history)*100, np.argmax(trainer.dev_score_history)+1
         logger.info("Best dev F1 = {:.2f}, at iteration = {}".format(best_f, best_eval * args['eval_interval']))
     else:
         logger.info("Dev set never evaluated. Saving final model.")
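
The TODO above is easier to see with concrete numbers; a small self-contained illustration of how the reported iteration is derived (all values invented):

import numpy as np

dev_score_history = [0.71, 0.74, 0.78, 0.81, 0.80]   # one entry per dev evaluation
eval_interval = 100

best_f = max(dev_score_history) * 100                 # 81.00
best_eval = np.argmax(dev_score_history) + 1          # best score came from the 4th eval
print("Best dev F1 = {:.2f}, at iteration = {}".format(best_f, best_eval * eval_interval))
# -> Best dev F1 = 81.00, at iteration = 400
# If the first three evals had come from a checkpoint trained with
# eval_interval=200, the true step of the best score would be
# 3*200 + 1*100 = 700, so the reported 400 would be off; hence the TODO.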
