Skip to content

Commit ca8fae6

Browse files
hjjpkuzhangpzh
authored andcommitted
haungjj:load pretrained detector for mix up model (facebookresearch#2)
2333
1 parent 94b25ec commit ca8fae6

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

maskrcnn_benchmark/huangjj.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# huangjj-pkusz-pcl
2+
# load the pretrain fpn(it should be a checkpoint, and the last checkpoint is loaded as default) to initialize the mix-up model
3+
import logging
4+
import torch
5+
import os
6+
from maskrcnn_benchark.utils.model_serialization import load_state_dict
7+
8+
9+
def get_checkpoint_file(path):
10+
save_file = path
11+
try:
12+
with open(save_file, "r") as f:
13+
last_saved = f.read()
14+
last_saved = last_saved.strip()
15+
except IOError:
16+
# if file doesn't exist, maybe because it has just been
17+
# deleted by a separate process
18+
last_saved = ""
19+
return last_saved
20+
21+
def load_file(cfg,f):
22+
if f.endswith(".pkl"):
23+
return load_c2_format(cfg, f)
24+
return torch.load(f, map_location=torch.device("cpu"))
25+
26+
def load_model(model, checkpoint)
27+
load_state_dict(model, checkpoint.pop("model"))
28+
return model
29+
30+
def load_pretrain_detector(cfg, model)
31+
logger = logging.getLogger(__name__)
32+
33+
path_to_model = cfg.PRETRAIN_DET_DIR
34+
save_file = os.path.join(path_to_model, "last_checkpoint")
35+
36+
if os.path.exists(save_file):
37+
f = get_model_file(path_to_model)
38+
if not f:
39+
logger.infor("No pretrained detecotr found.")
40+
exit(0)
41+
42+
checkpoint = load_file(cfg,f)
43+
44+
return load_model(model, checkpoint)
45+
46+

0 commit comments

Comments
 (0)