Skip to content

Commit 9d4c24e

Browse files
committed
Match the TensorFlow model definition so that pretrained TensorFlow models can be used.
Also provide a script that converts a TensorFlow model to a PyTorch model.
1 parent f7dee31 commit 9d4c24e

File tree

2 files changed

+78
-3
lines changed

2 files changed

+78
-3
lines changed

lib/nets/resnet_v1.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,13 +118,13 @@ def __init__(self, block, layers, num_classes=1000):
118118
bias=False)
119119
self.bn1 = nn.BatchNorm2d(64)
120120
self.relu = nn.ReLU(inplace=True)
121+
# Note: tf-faster-rcnn uses a max-pool with padding 1 instead of ceil_mode, but this does not affect the output much
121122
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change
122123
self.layer1 = self._make_layer(block, 64, layers[0])
123124
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
124125
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
125-
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
126-
self.avgpool = nn.AvgPool2d(7)
127-
self.fc = nn.Linear(512 * block.expansion, num_classes)
126+
# use stride 1 for the last conv4 layer (same as tf-faster-rcnn)
127+
self.layer4 = self._make_layer(block, 512, layers[3], stride=1)
128128

129129
for m in self.modules():
130130
if isinstance(m, nn.Conv2d):

tools/convert_from_tensorflow.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import tensorflow as tf
2+
from tensorflow.python import pywrap_tensorflow
3+
from collections import OrderedDict
4+
import re
5+
import torch
6+
7+
import argparse
8+
# Ordered (TensorFlow substring -> PyTorch substring) rewrites applied to every
# variable name.  Order matters: e.g. 'weights' must become 'weight' before the
# 'shortcut/weight' rule can match, and '/' -> '.' has to run last because the
# earlier rules match on '/'-separated paths.
dummy_replace = OrderedDict([
    ('moving_mean', 'running_mean'),
    ('moving_variance', 'running_var'),
    ('weights', 'weight'),
    ('biases', 'bias'),
    ('conv1/BatchNorm', 'bn1'),
    ('conv2/BatchNorm', 'bn2'),
    ('conv3/BatchNorm', 'bn3'),
    ('bottleneck_v1/', ''),
    ('block', 'layer'),
    ('resnet/rpn_conv/3x3', 'rpn_net'),
    ('resnet/rpn_cls_score', 'rpn_cls_score_net'),
    ('resnet/cls_score', 'cls_score_net'),
    ('resnet/rpn_bbox_pred', 'rpn_bbox_pred_net'),
    ('resnet/bbox_pred', 'bbox_pred_net'),
    ('shortcut/weight', 'downsample.0.weight'),
    ('shortcut/BatchNorm', 'downsample.1'),
    ('gamma', 'weight'),
    ('beta', 'bias'),
    ('/', '.'),
])

# Matches the 1-based residual unit index embedded in a variable name.
_UNIT_PATTERN = re.compile(r'unit_(\d+)')


def rename_variable(name):
    """Translate one TensorFlow variable name into its PyTorch key.

    The leading scope (everything before the first '/') is replaced by
    'resnet', the substring rewrites in ``dummy_replace`` are applied in
    order, and every 'unit_<N>' component becomes the 0-based index '<N-1>'.
    Names without a '/' keep their leading part unchanged.
    """
    slash = name.find('/')
    if slash >= 0:
        name = 'resnet' + name[slash:]
    for old, new in dummy_replace.items():
        name = name.replace(old, new)
    return _UNIT_PATTERN.sub(lambda m: str(int(m.group(1)) - 1), name)


def to_pytorch_layout(array):
    """Return ``array`` reordered from TensorFlow to PyTorch axis order.

    4-D arrays are permuted with (3, 2, 0, 1), i.e. a TensorFlow
    (height, width, in, out) convolution kernel becomes (out, in, height,
    width); 2-D arrays are transposed.  Arrays of any other rank are returned
    unchanged.  Reordered arrays are copied into C-contiguous memory so that
    ``torch.from_numpy`` receives a plain dense buffer.
    """
    if array.ndim == 4:
        return array.transpose((3, 2, 0, 1)).copy(order='C')
    if array.ndim == 2:
        return array.transpose((1, 0)).copy(order='C')
    return array


def convert_variables(var_dict):
    """Build a PyTorch state dict from a mapping of TensorFlow variables.

    Args:
        var_dict: mapping of TensorFlow variable name -> numpy array.  It is
            not modified.

    Returns:
        A new dict mapping PyTorch parameter names to ``torch`` tensors.

    The 'Variable' entry (presumably the global-step counter -- confirm
    against the training script) and every optimizer 'Momentum' slot are
    dropped.  A fresh dict is built instead of renaming keys in place:
    mutating a dict while iterating over its keys raises RuntimeError on
    Python 3, and renaming a key to itself used to delete the entry.
    """
    converted = {}
    for name, value in var_dict.items():
        if name == 'Variable' or 'Momentum' in name:
            continue
        converted[rename_variable(name)] = torch.from_numpy(to_pytorch_layout(value))
    return converted


def output_path_for(checkpoint_path):
    """Return the '.pth' path that corresponds to ``checkpoint_path``.

    Everything from the first '.ckpt' onwards is replaced by '.pth'.  If the
    path contains no '.ckpt', '.pth' is appended to the whole path rather
    than chopping off its last character.
    """
    stem_end = checkpoint_path.find('.ckpt')
    stem = checkpoint_path if stem_end < 0 else checkpoint_path[:stem_end]
    return stem + '.pth'


def main():
    """Read the checkpoint named on the command line and save it as '.pth'."""
    parser = argparse.ArgumentParser(description='Convert tf-faster-rcnn model to pytorch-faster-rcnn model')
    parser.add_argument('--tensorflow_model',
                        help='the path of tensorflow_model',
                        default=None, type=str)
    args = parser.parse_args()
    if args.tensorflow_model is None:
        # Fail with a usage message instead of handing None to the reader.
        parser.error('--tensorflow_model is required')

    reader = pywrap_tensorflow.NewCheckpointReader(args.tensorflow_model)
    var_dict = {name: reader.get_tensor(name)
                for name in reader.get_variable_to_shape_map()}

    torch.save(convert_variables(var_dict), output_path_for(args.tensorflow_model))


if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)