import sys
import os
+ import copy
import shutil
import time
import argparse
@@ -43,7 +44,6 @@ def build_argparse():
parser.add_argument('--eval_file', type=str, default=None, help='Input file for data loader.')
parser.add_argument('--output_file', type=str, default=None, help='Output CoNLL-U file.')
parser.add_argument('--gold_file', type=str, default=None, help='Output CoNLL-U file.')
-
parser.add_argument('--mode', default='train', choices=['train', 'predict'])
parser.add_argument('--lang', type=str, help='Language')
parser.add_argument('--shorthand', type=str, help="Treebank shorthand")
@@ -58,6 +58,8 @@ def build_argparse():
parser.add_argument('--transformed_dim', type=int, default=125)
parser.add_argument('--num_layers', type=int, default=3)
parser.add_argument('--char_num_layers', type=int, default=1)
+ parser.add_argument('--checkpoint_save_name', type=str, default=None, help="File name to save the most recent checkpoint")
+ parser.add_argument('--no_checkpoint', dest='checkpoint', action='store_false', help="Don't save checkpoints")
parser.add_argument('--pretrain_max_vocab', type=int, default=250000)
parser.add_argument('--word_dropout', type=float, default=0.33)
parser.add_argument('--dropout', type=float, default=0.5)
@@ -77,24 +79,28 @@ def build_argparse():
parser.add_argument('--bert_finetune', default=False, action='store_true', help='Finetune the bert (or other transformer)')
parser.add_argument('--no_bert_finetune', dest='bert_finetune', action='store_false', help="Don't finetune the bert (or other transformer)")
parser.add_argument('--bert_learning_rate', default=1.0, type=float, help='Scale the learning rate for transformer finetuning by this much')
+ parser.add_argument('--second_bert_learning_rate', default=1e-3, type=float, help='Secondary stage transformer finetuning learning rate scale')

parser.add_argument('--no_pretrain', dest='pretrain', action='store_false', help="Turn off pretrained embeddings.")
parser.add_argument('--no_linearization', dest='linearization', action='store_false', help="Turn off linearization term.")
parser.add_argument('--no_distance', dest='distance', action='store_false', help="Turn off distance term.")

parser.add_argument('--sample_train', type=float, default=1.0, help='Subsample training data.')
parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam or adamax.')
+ parser.add_argument('--second_optim', type=str, default=None, help='sgd, adagrad, adam or adamax.')
parser.add_argument('--lr', type=float, default=3e-3, help='Learning rate')
+ parser.add_argument('--second_lr', type=float, default=3e-4, help='Secondary stage learning rate')
parser.add_argument('--beta2', type=float, default=0.95)

parser.add_argument('--max_steps', type=int, default=50000)
parser.add_argument('--eval_interval', type=int, default=100)
- parser.add_argument('--max_steps_before_stop', type=int, default=3000)
+ parser.add_argument('--max_steps_before_stop', type=int, default=1000)
parser.add_argument('--batch_size', type=int, default=5000)
parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Gradient clipping.')
parser.add_argument('--log_step', type=int, default=20, help='Print log every k steps.')
parser.add_argument('--save_dir', type=str, default='saved_models/depparse', help='Root dir for saving models.')
parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_parser.pt", help="File name to save the model")
+ parser.add_argument('--continue_from', type=str, default=None, help="File name to preload the model to continue training from")

parser.add_argument('--seed', type=int, default=1234)
utils.add_device_args(parser)
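A quick way to sanity-check the new flags is to run them through build_argparse directly. This is only a sketch: it assumes the file is importable as stanza.models.parser, that build_argparse has no required arguments beyond those shown in this diff, and the sample values are made up.

    # Illustrative only: exercise the new checkpoint / second-stage flags.
    from stanza.models.parser import build_argparse   # assumed module path

    parser = build_argparse()
    args = parser.parse_args([
        '--shorthand', 'en_ewt',                       # hypothetical treebank
        '--second_optim', 'adamax',                    # used after the first plateau
        '--second_lr', '3e-4',
        '--checkpoint_save_name', 'en_ewt_parser_checkpoint.pt',
    ])
    print(args.checkpoint)               # True unless --no_checkpoint is passed
    print(args.max_steps_before_stop)    # now defaults to 1000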
@@ -191,7 +197,19 @@ def train(args):
wandb.run.define_metric('dev_score', summary='max')

logger.info("Training parser...")
- trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'])
+ # calculate the checkpoint file name and which saved model (if any) to load
+ model_to_load = None  # used for general loading and reloading
+ checkpoint_file = None  # explicit path to the checkpoint; stays None if we don't want to save checkpoints
+ if args.get("checkpoint"):
+     model_to_load = utils.checkpoint_name(args.get("save_dir"), model_file, args.get("checkpoint_save_name"))
+     checkpoint_file = copy.deepcopy(model_to_load)
+ if args["continue_from"]:
+     model_to_load = args["continue_from"]
+
+ if model_to_load is not None and os.path.exists(model_to_load):
+     trainer = Trainer(args=args, pretrain=pretrain, vocab=vocab, model_file=model_to_load, device=args['device'], ignore_model_config=True)
+ else:
+     trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'])

global_step = 0
max_steps = args['max_steps']
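In other words, the block above prefers --continue_from, then an existing checkpoint, and otherwise builds a fresh Trainer. The helper below restates that decision as a standalone function; resolve_model_to_load and the *_checkpoint naming are illustrative assumptions, not stanza's actual utils.checkpoint_name.

    import os

    def resolve_model_to_load(args, model_file):
        """Illustrative restatement of the warm-start logic above (not the real helper)."""
        model_to_load = None     # file to resume/preload from, if any
        checkpoint_file = None   # where periodic checkpoints get written
        if args.get("checkpoint"):
            # assumption: default checkpoint name derived from the model file name
            name = args.get("checkpoint_save_name") or os.path.basename(model_file) + "_checkpoint"
            checkpoint_file = os.path.join(args.get("save_dir", "."), name)
            model_to_load = checkpoint_file
        if args.get("continue_from"):
            model_to_load = args["continue_from"]    # explicit preload wins
        if model_to_load is not None and not os.path.exists(model_to_load):
            model_to_load = None                     # nothing on disk; start fresh
        return model_to_load, checkpoint_file

Either way checkpoint_file keeps pointing at the checkpoint path, so an interrupted run restarted with the same command resumes from the latest checkpoint rather than from scratch.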
@@ -201,7 +219,7 @@ def train(args):
global_start_time = time.time()
format_str = 'Finished STEP {}/{}, loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'

- using_amsgrad = False
+ is_second_stage = False
last_best_step = 0
# start training
train_loss = 0
@@ -247,15 +265,27 @@ def train(args):
dev_score_history += [dev_score]

if global_step - last_best_step >= args['max_steps_before_stop']:
-     if not using_amsgrad:
-         logger.info("Switching to AMSGrad")
+     if not is_second_stage and args.get('second_optim', None) is not None:
+         logger.info("Switching to second optimizer: {}".format(args.get('second_optim', None)))
+         args["second_stage"] = True
+         # if the loader gets a model file, it uses the secondary optimizer
+         trainer = Trainer(args=args, vocab=trainer.vocab, pretrain=pretrain,
+                           model_file=model_file, device=args['device'])
+         logger.info('Reloading best model to continue from current local optimum')
+         is_second_stage = True
        last_best_step = global_step
-         using_amsgrad = True
-         trainer.optimizer = optim.Adam(trainer.model.parameters(), amsgrad=True, lr=args['lr'], betas=(.9, args['beta2']), eps=1e-6)
    else:
        do_break = True
        break

+ if global_step % args['eval_interval'] == 0:
+     # if we need to save a checkpoint, do so
+     # (save after switching the optimizer, if applicable, so that
+     # the new optimizer is the one used if a restart happens)
+     if checkpoint_file is not None:
+         trainer.save(checkpoint_file, save_optimizer=True)
+         logger.info("new model checkpoint saved.")
+
if global_step >= args['max_steps']:
    do_break = True
    break
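The plateau handling now rebuilds the trainer from the best saved model with args['second_stage'] set, instead of mutating the existing optimizer into AMSGrad in place, so the optimizer choice lives wherever the Trainer is constructed. A minimal sketch of what that selection could look like, using plain torch.optim; build_optimizer and the exact treatment of second_bert_learning_rate are assumptions, not stanza's actual trainer code.

    import torch.optim as optim

    def build_optimizer(args, model):
        """Illustrative sketch: pick first- or second-stage optimizer settings from args."""
        if args.get("second_stage") and args.get("second_optim"):
            name, lr = args["second_optim"], args["second_lr"]
        else:
            name, lr = args["optim"], args["lr"]
        builders = {
            "sgd": lambda p: optim.SGD(p, lr=lr),
            "adagrad": lambda p: optim.Adagrad(p, lr=lr),
            "adam": lambda p: optim.Adam(p, lr=lr, betas=(0.9, args["beta2"]), eps=1e-6),
            "adamax": lambda p: optim.Adamax(p, lr=lr, betas=(0.9, args["beta2"]), eps=1e-6),
        }
        return builders[name](model.parameters())

Because the checkpoint saved at each eval_interval includes the optimizer state (save_optimizer=True), a restart after the switch resumes in the second stage rather than falling back to the first optimizer.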