|
| 1 | +import tensorflow as tf |
| 2 | +from tensorflow.python import pywrap_tensorflow |
| 3 | +from collections import OrderedDict |
| 4 | +import re |
| 5 | +import torch |
| 6 | + |
| 7 | +import argparse |
| 8 | +parser = argparse.ArgumentParser(description='Convert tf-faster-rcnn model to pytorch-faster-rcnn model') |
| 9 | +parser.add_argument('--tensorflow_model', |
| 10 | + help='the path of tensorflow_model', |
| 11 | + default=None, type=str) |
| 12 | + |
| 13 | +args = parser.parse_args() |
| 14 | + |
| 15 | +reader = pywrap_tensorflow.NewCheckpointReader(args.tensorflow_model) |
| 16 | +var_to_shape_map = reader.get_variable_to_shape_map() |
| 17 | +var_dict = {k:reader.get_tensor(k) for k in var_to_shape_map.keys()} |
| 18 | + |
| 19 | +del var_dict['Variable'] |
| 20 | + |
| 21 | +for k in var_dict.keys(): |
| 22 | + if 'Momentum' in k: |
| 23 | + del var_dict[k] |
| 24 | + |
| 25 | +for k in var_dict.keys(): |
| 26 | + if k.find('/') >= 0: |
| 27 | + var_dict['resnet' + k[k.find('/'):]] = var_dict[k] |
| 28 | + del var_dict[k] |
| 29 | + |
| 30 | +dummy_replace = OrderedDict([ |
| 31 | + ('moving_mean', 'running_mean'),\ |
| 32 | + ('moving_variance', 'running_var'),\ |
| 33 | + ('weights', 'weight'),\ |
| 34 | + ('biases', 'bias'),\ |
| 35 | + ('conv1/BatchNorm', 'bn1'),\ |
| 36 | + ('conv2/BatchNorm', 'bn2'),\ |
| 37 | + ('conv3/BatchNorm', 'bn3'),\ |
| 38 | + ('bottleneck_v1/', ''),\ |
| 39 | + ('block', 'layer'),\ |
| 40 | + ('resnet/rpn_conv/3x3', 'rpn_net'),\ |
| 41 | + ('resnet/rpn_cls_score', 'rpn_cls_score_net'),\ |
| 42 | + ('resnet/cls_score', 'cls_score_net'),\ |
| 43 | + ('resnet/rpn_bbox_pred', 'rpn_bbox_pred_net'),\ |
| 44 | + ('resnet/bbox_pred', 'bbox_pred_net'),\ |
| 45 | + ('shortcut/weight', 'downsample.0.weight'),\ |
| 46 | + ('shortcut/BatchNorm', 'downsample.1'),\ |
| 47 | + ('gamma', 'weight'),\ |
| 48 | + ('beta', 'bias'),\ |
| 49 | + ('/', '.')]) |
| 50 | + |
| 51 | +for a, b in dummy_replace.items(): |
| 52 | + for k in var_dict.keys(): |
| 53 | + if a in k: |
| 54 | + var_dict[k.replace(a,b)] = var_dict[k] |
| 55 | + del var_dict[k] |
| 56 | + |
| 57 | + |
| 58 | +for k in var_dict.keys(): |
| 59 | + if 'unit_' in k: |
| 60 | + m = re.search('unit_(\d+)', k) |
| 61 | + var_dict[k.replace(m.group(0), str(int(m.group(1)) - 1))] = var_dict[k] |
| 62 | + del var_dict[k] |
| 63 | + |
| 64 | +for k in var_dict.keys(): |
| 65 | + if var_dict[k].ndim == 4: |
| 66 | + var_dict[k] = var_dict[k].transpose((3, 2, 0, 1)).copy(order='C') |
| 67 | + if var_dict[k].ndim == 2: |
| 68 | + var_dict[k] = var_dict[k].transpose((1, 0)).copy(order='C') |
| 69 | + # assert x[k].shape == var_dict[k].shape, k |
| 70 | + |
| 71 | +for k in var_dict.keys(): |
| 72 | + var_dict[k] = torch.from_numpy(var_dict[k]) |
| 73 | + |
| 74 | + |
| 75 | +torch.save(var_dict, args.tensorflow_model[:args.tensorflow_model.find('.ckpt')]+'.pth') |
0 commit comments