File tree 1 file changed +46
-0
lines changed
1 file changed +46
-0
lines changed Original file line number Diff line number Diff line change
1
+ # huangjj-pkusz-pcl
2
+ # load the pretrain fpn(it should be a checkpoint, and the last checkpoint is loaded as default) to initialize the mix-up model
3
+ import logging
4
+ import torch
5
+ import os
6
+ from maskrcnn_benchark .utils .model_serialization import load_state_dict
7
+
8
+
9
+ def get_checkpoint_file (path ):
10
+ save_file = path
11
+ try :
12
+ with open (save_file , "r" ) as f :
13
+ last_saved = f .read ()
14
+ last_saved = last_saved .strip ()
15
+ except IOError :
16
+ # if file doesn't exist, maybe because it has just been
17
+ # deleted by a separate process
18
+ last_saved = ""
19
+ return last_saved
20
+
21
+ def load_file (cfg ,f ):
22
+ if f .endswith (".pkl" ):
23
+ return load_c2_format (cfg , f )
24
+ return torch .load (f , map_location = torch .device ("cpu" ))
25
+
26
+ def load_model (model , checkpoint )
27
+ load_state_dict (model , checkpoint .pop ("model" ))
28
+ return model
29
+
30
+ def load_pretrain_detector (cfg , model )
31
+ logger = logging .getLogger (__name__ )
32
+
33
+ path_to_model = cfg .PRETRAIN_DET_DIR
34
+ save_file = os .path .join (path_to_model , "last_checkpoint" )
35
+
36
+ if os .path .exists (save_file ):
37
+ f = get_model_file (path_to_model )
38
+ if not f :
39
+ logger .infor ("No pretrained detecotr found." )
40
+ exit (0 )
41
+
42
+ checkpoint = load_file (cfg ,f )
43
+
44
+ return load_model (model , checkpoint )
45
+
46
+
You can’t perform that action at this time.
0 commit comments