@@ -198,45 +198,42 @@ def train(args):
         wandb.run.define_metric('dev_score', summary='max')

     logger.info("Training parser...")
-    # calculate checkpoint file name and the sav
-    model_to_load = None # used for general loading and reloading
-    checkpoint_file = None # used explicitly as the *PATH TO THE CHECKPOINT* could be None if we don't want to save chkpt
     if args.get("checkpoint"):
-        model_to_load = utils.checkpoint_name(args.get("save_dir"), model_file, args.get("checkpoint_save_name"))
-        checkpoint_file = model_to_load
+        # calculate checkpoint file name from the save filename
+        checkpoint_file = utils.checkpoint_name(args.get("save_dir"), model_file, args.get("checkpoint_save_name"))
         args["checkpoint_save_name"] = checkpoint_file
-    if args["continue_from"]:
-        model_to_load = args["continue_from"]

-    if model_to_load is not None and os.path.exists(model_to_load):
-        trainer = Trainer(args=args, pretrain=pretrain, vocab=vocab, model_file=model_to_load, device=args['device'], ignore_model_config=True)
+    if args.get("checkpoint") and os.path.exists(args["checkpoint_save_name"]):
+        trainer = Trainer(args=args, pretrain=pretrain, vocab=vocab, model_file=args["checkpoint_save_name"], device=args['device'], ignore_model_config=True)
+        if len(trainer.dev_score_history) > 0:
+            logger.info("Continuing from checkpoint %s Model was previously trained for %d steps, with a best dev score of %.4f", args["checkpoint_save_name"], trainer.global_step, max(trainer.dev_score_history))
+    elif args["continue_from"]:
+        if not os.path.exists(args["continue_from"]):
+            raise FileNotFoundError("--continue_from specified, but the file %s does not exist" % args["continue_from"])
+        trainer = Trainer(args=args, pretrain=pretrain, vocab=vocab, model_file=args["continue_from"], device=args['device'], ignore_model_config=True, reset_history=True)
     else:
         trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, device=args['device'])

-    global_step = 0
     max_steps = args['max_steps']
-    dev_score_history = []
-    best_dev_preds = []
     current_lr = args['lr']
     global_start_time = time.time()
     format_str = 'Finished STEP {}/{}, loss = {:.6f} ({:.3f} sec/batch), lr: {:.6f}'

     is_second_stage = False
-    last_best_step = 0
     # start training
     train_loss = 0
     while True:
         do_break = False
         for i, batch in enumerate(train_batch):
             start_time = time.time()
-            global_step += 1
+            trainer.global_step += 1
             loss = trainer.update(batch, eval=False) # update step
             train_loss += loss
-            if global_step % args['log_step'] == 0:
+            if trainer.global_step % args['log_step'] == 0:
                 duration = time.time() - start_time
-                logger.info(format_str.format(global_step, max_steps, loss, duration, current_lr))
+                logger.info(format_str.format(trainer.global_step, max_steps, loss, duration, current_lr))

-            if global_step % args['eval_interval'] == 0:
+            if trainer.global_step % args['eval_interval'] == 0:
                 # eval on dev
                 logger.info("Evaluating on dev set...")
                 dev_preds = []
@@ -250,60 +247,62 @@ def train(args):
                 _, _, dev_score = scorer.score(system_pred_file, gold_file)

                 train_loss = train_loss / args['eval_interval'] # avg loss per batch
-                logger.info("step {}: train_loss = {:.6f}, dev_score = {:.4f}".format(global_step, train_loss, dev_score))
+                logger.info("step {}: train_loss = {:.6f}, dev_score = {:.4f}".format(trainer.global_step, train_loss, dev_score))

                 if args['wandb']:
                     wandb.log({'train_loss': train_loss, 'dev_score': dev_score})

                 train_loss = 0

                 # save best model
-                if len(dev_score_history) == 0 or dev_score > max(dev_score_history):
-                    last_best_step = global_step
+                if len(trainer.dev_score_history) == 0 or dev_score > max(trainer.dev_score_history):
+                    trainer.last_best_step = trainer.global_step
                     trainer.save(model_file)
                     logger.info("new best model saved.")
-                    best_dev_preds = dev_preds

-                dev_score_history += [dev_score]
+                trainer.dev_score_history += [dev_score]

             if not is_second_stage and args.get('second_optim', None) is not None:
-                if global_step - last_best_step >= args['max_steps_before_stop'] or (args['second_optim_start_step'] is not None and global_step >= args['second_optim_start_step']):
+                if trainer.global_step - trainer.last_best_step >= args['max_steps_before_stop'] or (args['second_optim_start_step'] is not None and trainer.global_step >= args['second_optim_start_step']):
                     logger.info("Switching to second optimizer: {}".format(args.get('second_optim', None)))
                     args["second_stage"] = True
                     # if the loader gets a model file, it uses secondary optimizer
                     trainer = Trainer(args=args, vocab=trainer.vocab, pretrain=pretrain,
                                       model_file=model_file, device=args['device'])
                     logger.info('Reloading best model to continue from current local optimum')
                     is_second_stage = True
-                    last_best_step = global_step
+                    trainer.last_best_step = trainer.global_step
             else:
-                if global_step - last_best_step >= args['max_steps_before_stop']:
+                if trainer.global_step - trainer.last_best_step >= args['max_steps_before_stop']:
                     do_break = True
                     break

-            if global_step % args['eval_interval'] == 0:
+            if trainer.global_step % args['eval_interval'] == 0:
                 # if we need to save checkpoint, do so
                 # (save after switching the optimizer, if applicable, so that
                 #  the new optimizer is the optimizer used if a restart happens)
                 if checkpoint_file is not None:
                     trainer.save(checkpoint_file, save_optimizer=True)
                     logger.info("new model checkpoint saved.")

-            if global_step >= args['max_steps']:
+            if trainer.global_step >= args['max_steps']:
                 do_break = True
                 break

         if do_break: break

         train_batch.reshuffle()

-    logger.info("Training ended with {} steps.".format(global_step))
+    logger.info("Training ended with {} steps.".format(trainer.global_step))

     if args['wandb']:
         wandb.finish()

-    if len(dev_score_history) > 0:
-        best_f, best_eval = max(dev_score_history)*100, np.argmax(dev_score_history)+1
+    if len(trainer.dev_score_history) > 0:
+        # TODO: technically the iteration position will be wrong if
+        # the eval_interval changed when running from a checkpoint
+        # could fix this by saving step & score instead of just score
+        best_f, best_eval = max(trainer.dev_score_history)*100, np.argmax(trainer.dev_score_history)+1
         logger.info("Best dev F1 = {:.2f}, at iteration = {}".format(best_f, best_eval * args['eval_interval']))
     else:
         logger.info("Dev set never evaluated. Saving final model.")
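The change above hinges on the training-progress counters (`global_step`, `last_best_step`, `dev_score_history`) living on the trainer object instead of as locals in `train()`, so whatever `trainer.save(checkpoint_file, save_optimizer=True)` writes can bring them back when training restarts. The sketch below illustrates that round-trip under stated assumptions: `TinyTrainer` and its fields are illustrative stand-ins, not the Stanza `Trainer` API, and its `reset_history` flag only mimics the `--continue_from` behavior shown in the diff (reuse the weights, restart the counters).

import os
import torch

class TinyTrainer:
    """Illustrative stand-in: a trainer that checkpoints its own progress."""

    def __init__(self, model_file=None, reset_history=False):
        self.model = torch.nn.Linear(4, 2)          # placeholder model
        # training-progress state carried by the trainer itself
        self.global_step = 0
        self.last_best_step = 0
        self.dev_score_history = []
        if model_file is not None and os.path.exists(model_file):
            self.load(model_file, reset_history=reset_history)

    def save(self, filename):
        torch.save({'model': self.model.state_dict(),
                    'global_step': self.global_step,
                    'last_best_step': self.last_best_step,
                    'dev_score_history': self.dev_score_history}, filename)

    def load(self, filename, reset_history=False):
        state = torch.load(filename)
        self.model.load_state_dict(state['model'])
        if reset_history:
            return                                  # weights only, counters start over
        self.global_step = state['global_step']
        self.last_best_step = state['last_best_step']
        self.dev_score_history = state['dev_score_history']

if __name__ == '__main__':
    ckpt = 'tiny_checkpoint.pt'
    trainer = TinyTrainer(model_file=ckpt)          # resumes if a checkpoint already exists
    for _ in range(5):
        trainer.global_step += 1                    # the loop advances trainer state directly
    trainer.dev_score_history.append(0.42)
    trainer.save(ckpt)
    resumed = TinyTrainer(model_file=ckpt)
    print(resumed.global_step, resumed.dev_score_history)   # 5 [0.42]

Keeping this state on the trainer also removes the stale locals the diff deletes (`best_dev_preds`, the separate `global_step`): the loop mutates one object, and in the diff the periodic checkpoint and the best-model file differ only in the `save_optimizer` flag passed to `trainer.save`.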