Skip to content

Commit e64f9b1

Browse files
committed
Add a short test method verifying that the single-optimizer case saves checkpoints and that those checkpoints are loadable
Add a flag which forces the optimizer to switch after a given number of steps — useful for writing tests that check the behavior of the second optimizer
1 parent 8986b5b commit e64f9b1

File tree

2 files changed

+50
-7
lines changed

2 files changed

+50
-7
lines changed

stanza/models/parser.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def build_argparse():
9191
parser.add_argument('--lr', type=float, default=3e-3, help='Learning rate')
9292
parser.add_argument('--second_lr', type=float, default=3e-4, help='Secondary stage learning rate')
9393
parser.add_argument('--beta2', type=float, default=0.95)
94+
parser.add_argument('--second_optim_start_step', type=int, default=None, help='If set, switch to the second optimizer when stalled or at this step regardless of performance. Normally, the optimizer only switches when the dev scores have stalled for --max_steps_before_stop steps')
9495

9596
parser.add_argument('--max_steps', type=int, default=50000)
9697
parser.add_argument('--eval_interval', type=int, default=100)
@@ -129,7 +130,7 @@ def main(args=None):
129130
logger.info("Running parser in {} mode".format(args['mode']))
130131

131132
if args['mode'] == 'train':
132-
train(args)
133+
return train(args)
133134
else:
134135
evaluate(args)
135136

@@ -202,7 +203,8 @@ def train(args):
202203
checkpoint_file = None # used explicitly as the *PATH TO THE CHECKPOINT* could be None if we don't want to save chkpt
203204
if args.get("checkpoint"):
204205
model_to_load = utils.checkpoint_name(args.get("save_dir"), model_file, args.get("checkpoint_save_name"))
205-
checkpoint_file = copy.deepcopy(model_to_load)
206+
checkpoint_file = model_to_load
207+
args["checkpoint_save_name"] = checkpoint_file
206208
if args["continue_from"]:
207209
model_to_load = args["continue_from"]
208210

@@ -264,8 +266,8 @@ def train(args):
264266

265267
dev_score_history += [dev_score]
266268

267-
if global_step - last_best_step >= args['max_steps_before_stop']:
268-
if not is_second_stage and args.get('second_optim', None) is not None:
269+
if not is_second_stage and args.get('second_optim', None) is not None:
270+
if global_step - last_best_step >= args['max_steps_before_stop'] or (args['second_optim_start_step'] is not None and global_step >= args['second_optim_start_step']):
269271
logger.info("Switching to second optimizer: {}".format(args.get('second_optim', None)))
270272
args["second_stage"] = True
271273
# if the loader gets a model file, it uses secondary optimizer
@@ -274,7 +276,8 @@ def train(args):
274276
logger.info('Reloading best model to continue from current local optimum')
275277
is_second_stage = True
276278
last_best_step = global_step
277-
else:
279+
else:
280+
if global_step - last_best_step >= args['max_steps_before_stop']:
278281
do_break = True
279282
break
280283

@@ -306,6 +309,7 @@ def train(args):
306309
logger.info("Dev set never evaluated. Saving final model.")
307310
trainer.save(model_file)
308311

312+
return trainer
309313

310314
def evaluate(args):
311315
# file paths

stanza/tests/depparse/test_parser.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import os
88
import pytest
99

10+
import torch
11+
1012
from stanza.models import parser
1113
from stanza.models.common import pretrain
1214
from stanza.models.depparse.trainer import Trainer
@@ -108,12 +110,13 @@ def run_training(self, tmp_path, wordvec_pretrain_file, train_text, dev_text, au
108110
args.extend(["--augment_nopunct", "0.0"])
109111
if extra_args is not None:
110112
args = args + extra_args
111-
parser.main(args)
113+
trainer = parser.main(args)
112114

113115
assert os.path.exists(save_file)
114116
pt = pretrain.Pretrain(wordvec_pretrain_file)
117+
# test loading the saved model
115118
saved_model = Trainer(pretrain=pt, model_file=save_file)
116-
return saved_model
119+
return trainer
117120

118121
def test_train(self, tmp_path, wordvec_pretrain_file):
119122
"""
@@ -127,3 +130,39 @@ def test_with_bert(self, tmp_path, wordvec_pretrain_file):
127130
def test_with_bert_nlayers(self, tmp_path, wordvec_pretrain_file):
128131
self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--bert_model', 'hf-internal-testing/tiny-bert', '--bert_hidden_layers', '2'])
129132

133+
def test_single_optimizer_checkpoint(self, tmp_path, wordvec_pretrain_file):
134+
trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--optim', 'adam'])
135+
136+
save_dir = trainer.args['save_dir']
137+
save_name = trainer.args['save_name']
138+
checkpoint_name = trainer.args["checkpoint_save_name"]
139+
140+
assert os.path.exists(os.path.join(save_dir, save_name))
141+
assert checkpoint_name is not None
142+
assert os.path.exists(checkpoint_name)
143+
144+
assert isinstance(trainer.optimizer, torch.optim.Adam)
145+
146+
pt = pretrain.Pretrain(wordvec_pretrain_file)
147+
checkpoint = Trainer(args=trainer.args, pretrain=pt, model_file=checkpoint_name)
148+
assert checkpoint.optimizer is not None
149+
assert isinstance(checkpoint.optimizer, torch.optim.Adam)
150+
151+
def test_two_optimizers_checkpoint(self, tmp_path, wordvec_pretrain_file):
152+
trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=['--optim', 'adam', '--second_optim', 'sgd', '--second_optim_start_step', '40'])
153+
154+
save_dir = trainer.args['save_dir']
155+
save_name = trainer.args['save_name']
156+
checkpoint_name = trainer.args["checkpoint_save_name"]
157+
158+
assert os.path.exists(os.path.join(save_dir, save_name))
159+
assert checkpoint_name is not None
160+
assert os.path.exists(checkpoint_name)
161+
162+
assert isinstance(trainer.optimizer, torch.optim.SGD)
163+
164+
pt = pretrain.Pretrain(wordvec_pretrain_file)
165+
checkpoint = Trainer(args=trainer.args, pretrain=pt, model_file=checkpoint_name)
166+
assert checkpoint.optimizer is not None
167+
assert isinstance(checkpoint.optimizer, torch.optim.SGD)
168+

0 commit comments

Comments
 (0)