Problems encountered when converting a TensorFlow 1.14 model to PyTorch for inference deployment (SuperPoint)

GitHub - rpautrat/SuperPoint: Efficient neural feature detector and descriptor

1. Convert the checkpoint to pb

SuperPoint/export_model.py at master · rpautrat/SuperPoint · GitHub

2. freeze_graph

Reference 1:

Python freeze_graph.freeze_graph method code examples - 純淨天空

Reference 2:

Summary of TensorFlow model export - 知乎

Reference 3:

from tensorflow.python.tools import freeze_graph
from tensorflow.python.saved_model import tag_constants

input_saved_model_dir = "pb_model"

output_ver = False
if output_ver == True:
    output_node_names = "superpoint/prob_nms,superpoint/descriptors_raw,\
                                superpoint/logits,superpoint/prob,\
                                superpoint/descriptors,superpoint/pred"
else:
    output_node_names = "superpoint/prob_nms,superpoint/descriptors"

input_binary = False
input_saver_def_path = False
restore_op_name = None
filename_tensor_name = None
clear_devices = False
input_meta_graph = False
checkpoint_path = None
input_graph_filename = None
saved_model_tags = tag_constants.SERVING
output_graph_filename='frozen_graph2.pb'

freeze_graph.freeze_graph(input_graph_filename,
                input_saver_def_path,
                input_binary,
                checkpoint_path,
                output_node_names,
                restore_op_name,
                filename_tensor_name,
                output_graph_filename,
                clear_devices,
                "", "", "",
                input_meta_graph,
                input_saved_model_dir,
                saved_model_tags)

'''
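freeze_graph has 11 main parameters, described one by one below for reference:

input_graph: the model file, either a binary pb file or a text-format graph file; input_binary says which.
        A binary pb file corresponds to input_binary=True. The call above passes no input_graph
        and loads the model from input_saved_model_dir instead.

input_saver: a Saver definition, mainly for version incompatibilities. Usually empty, in which case a Saver from the current version is used.

input_binary: used together with input_graph. True means input_graph is binary, False means it is a text file. The default is False.

input_checkpoint: path of the checkpoint file.

output_node_names: names of the output nodes, comma-separated when there are several.
        In the example this description refers to, the output node is 'out', named via flow = tf.cast(flow, tf.int8, 'out').
        Without such a step, look up the model's actual output node names and pass them here;
        the call above uses superpoint/prob_nms and superpoint/descriptors.

restore_op_name: name of the op that restores the model, usually the default: save/restore_all

filename_tensor_name: usually the default: save/Const:0

output_graph: where to save the merged model, i.e. the path of the output pb file.

clear_devices: whether to clear the devices assigned to nodes during training (cpu, gpu, tpu; cpu is the default). The default is True.

initializer_nodes: empty by default. Nodes to initialize after the weights are loaded, as a comma-separated list of node names.

variable_names_blacklist: empty by default. Variables whose values should not be restored, as a comma-separated list of variable names.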
'''
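
A side note on assembling output_node_names: when a comma-separated list is continued across lines with a backslash inside the string literal, the indentation of the continuation lines becomes part of the string. freeze_graph may tolerate that, but other code that consumes the names will not necessarily do so. The sketch below uses only node names from the script above; the helper name join_node_names is hypothetical.

```python
def join_node_names(names):
    """Join node names into a comma-separated string without stray whitespace.

    Surrounding whitespace is stripped and empty entries are dropped.
    """
    cleaned = [n.strip() for n in names]
    return ",".join(n for n in cleaned if n)


# A backslash continuation inside a string literal keeps the indentation.
continued = "superpoint/prob_nms,\
    superpoint/descriptors"

# The same two node names, assembled without stray whitespace.
joined = join_node_names(["superpoint/prob_nms", "superpoint/descriptors"])

print(repr(continued))
print(repr(joined))
```

Building the string from a list also makes it easy to switch between the short and the full set of output nodes.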

3. Inference deployment (frozen-graph pb)

Reference 1 (drawn on heavily)

Frozen-Graph-TensorFlow/test_pb.py at master · leimao/Frozen-Graph-TensorFlow · GitHub

Lei Mao's Log Book – Save, Load and Inference From TensorFlow Frozen Graph

Reference 2 (drawn on heavily)

[Deep Learning] freeze_graph for TensorFlow models - 知乎

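Saving and loading models in TensorFlow is familiar to most people: the key statements saver = tf.train.Saver() and saver.save() are all it takes.

What may be less well known is that TensorFlow's checkpoint format stores the model structure and the weight data separately, which is inconvenient in some scenarios.

We therefore need a way to merge the model structure and the weights into a single file. TensorFlow provides the freeze_graph function and the pb file format to solve this.

What these model files are for

After save, the model is stored in the ckpt files. The checkpoint file records the list of all model files in the directory, and the events files are for the visualization tool TensorBoard.

Three files are directly tied to the saved model:

  • The .data file stores the current parameter values
  • The .index file stores the current parameter names
  • The .meta file stores the current graph structure

When you load a model with saver.restore(), it is this set of three checkpoint files that you are using.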
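
To make the three-file layout concrete, here is a minimal sketch that groups the files belonging to one checkpoint prefix. The helper checkpoint_files is hypothetical, and the demo creates empty stand-in files whose names only mimic a TF1 checkpoint (the prefix last.ckpt-600000 is the one used later in this post).

```python
import os
import tempfile


def checkpoint_files(prefix):
    """Return the files on disk that belong to one checkpoint prefix."""
    directory = os.path.dirname(prefix) or "."
    base = os.path.basename(prefix)
    found = {"data": [], "index": [], "meta": []}
    for name in sorted(os.listdir(directory)):
        if not name.startswith(base + "."):
            continue
        suffix = name[len(base) + 1:]
        if suffix.startswith("data-"):
            found["data"].append(name)    # parameter values (may be sharded)
        elif suffix == "index":
            found["index"].append(name)   # parameter names
        elif suffix == "meta":
            found["meta"].append(name)    # graph structure
    return found


# Demo with empty stand-in files; nothing here is a real checkpoint.
demo_dir = tempfile.mkdtemp()
for fname in ("last.ckpt-600000.data-00000-of-00001",
              "last.ckpt-600000.index",
              "last.ckpt-600000.meta",
              "checkpoint"):
    open(os.path.join(demo_dir, fname), "w").close()

files = checkpoint_files(os.path.join(demo_dir, "last.ckpt-600000"))
print(files)
```

Note that saver.restore() is given the prefix (here last.ckpt-600000), not any one of the three file names.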

What TensorFlow tricks do you wish you had known sooner? - 知乎

import os
import tensorflow as tf

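os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # select which GPU to use: 0, 1, 2, ...

config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # grow GPU memory usage as needed instead of grabbing it all up front
session = tf.Session(config=config)  # pass any other Session arguments as needed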

Reference 3 (drawn on lightly)

https://github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/examples/tutorials/deepdream/deepdream.ipynb

Reference 4 (drawn on lightly)

Freezing a trained TensorFlow model (baking the weights into the graph) and using it for prediction - lujiandong1的专栏 - CSDN博客

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import tensorflow as tf
import numpy as np
from tensorflow.python.framework import tensor_util

# If load from pb, you may have to use get_tensor_by_name heavily.

class SuperPoint_TF(object):
    def __init__(self, model_filepath, input_size=[None, 640, 480, 1], print_node=False, print_log=False):
        # The file path of model
        self.model_filepath = model_filepath
        self.input_size = input_size
        self.print_node = print_node
        self.print_log = print_log
        # Initialize the model
        self.load_graph(model_filepath=self.model_filepath)


    def load_graph(self, model_filepath):
        '''
        Load the trained model.
        '''
        print('Loading model...')
        self.graph = tf.Graph()

        # The file being loaded is the frozen-graph pb
        with tf.gfile.GFile(model_filepath, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        if self.print_log==True:
            print('Check out the input placeholders:')
            nodes = [
                n.name + ' => ' + n.op for n in graph_def.node
                if n.op == 'Placeholder'
            ]
            for node in nodes:
                print(node)

        # Context manager that makes self.graph the default graph
        with self.graph.as_default():
            # Define the input tensor (the input image)
            self.input = tf.placeholder(np.float32,
                                        shape=self.input_size,
                                        name='superpoint/image')
            # Import the model.
            # After import, every original node name gets the prefix import/
            tf.import_graph_def(graph_def, {
                'superpoint/image': self.input,
            })

        # TensorFlow Deep Learning, part 34: tf.Graph.finalize()
        # https://blog.csdn.net/DaVinciL/article/details/84251917
        # Finalize the graph: it becomes read-only and no more nodes can be added
        self.graph.finalize()

        print('Model loading complete!')

        if self.print_log==True:
            # Get layer names
            layers = [op.name for op in self.graph.get_operations()]
            for layer in layers:
                print(layer)

        if self.print_node==True:
            # Check out the weights of the nodes
            weight_nodes = [n for n in graph_def.node if n.op == 'Const']
            for n in weight_nodes:
                print("Name of the node - %s" % n.name)
                print("Value - " )
                print(tensor_util.MakeNdarray(n.attr['value'].tensor))

        # In this version, tf.InteractiveSession and tf.Session could be used interchangeably.
        # self.sess = tf.InteractiveSession(graph = self.graph)
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True  # grow GPU memory usage as needed instead of grabbing it all up front
        self.sess = tf.Session(graph=self.graph, config=config)

    def kps_detect(self, data):

        # Know your output node name
        output_tensor1 = self.graph.get_tensor_by_name("import/superpoint/prob_nms:0")
        output_tensor2 = self.graph.get_tensor_by_name("import/superpoint/descriptors:0")
        output = self.sess.run([output_tensor1, output_tensor2],
                               feed_dict={
                                   self.input: data,
                               })
        return output


model_filepath = 'superpoint/frozen_graph2.pb'
tf.reset_default_graph()
model = SuperPoint_TF(model_filepath=model_filepath, input_size=[None, 640, 480, 1])


while True:
    one_image = np.zeros((1,640,480,1))
    test_prediction = model.kps_detect(data=one_image)
    print(test_prediction)
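
The names passed to get_tensor_by_name above follow a fixed pattern: tf.import_graph_def puts every imported node under the default import/ scope, and the suffix :0 selects the first output of the op. A small sketch of that naming rule, with hypothetical helper names, may save some trial and error when looking up other nodes:

```python
def imported_tensor_name(node_name, prefix="import", output_index=0):
    """Build the tensor name to look up after tf.import_graph_def."""
    scoped = "{}/{}".format(prefix, node_name) if prefix else node_name
    return "{}:{}".format(scoped, output_index)


def split_tensor_name(tensor_name):
    """Split 'scope/op:0' into the op name and the output index."""
    op_name, _, index = tensor_name.rpartition(":")
    return op_name, int(index)


for node in ("superpoint/prob_nms", "superpoint/descriptors"):
    print(imported_tensor_name(node))
```

If import_graph_def is called with a different name argument, pass that value as prefix instead.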

To get the feature maps of intermediate TF layers

Fetch them with sess.run.

Reference:

Viewing the shape of feature maps in TensorFlow - runningwei的博客 - CSDN博客

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
#os.environ['TF_CPP_MIN_VLOG_LEVEL']='2'

import tensorflow as tf
import numpy as np
np.set_printoptions(suppress=True)
from tensorflow.python.framework import tensor_util

# If load from pb, you may have to use get_tensor_by_name heavily.

class SuperPoint_TF(object):
    def __init__(self, model_filepath, input_size=[None, 640, 480, 1], print_node=False, print_log=False):
        # The file path of model
        self.model_filepath = model_filepath
        self.input_size = input_size
        self.print_node = print_node
        self.print_log = print_log
        # Initialize the model
        self.load_graph(model_filepath=self.model_filepath)


    def load_graph(self, model_filepath):
        '''
        Load the trained model.
        '''
        print('Loading model...')
        self.graph = tf.Graph()

        # The file being loaded is the frozen-graph pb
        with tf.gfile.GFile(model_filepath, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        if self.print_log==True:
            print('Check out the input placeholders:')
            nodes = [
                n.name + ' => ' + n.op for n in graph_def.node
                if n.op == 'Placeholder'
            ]
            for node in nodes:
                print(node)

        # Context manager that makes self.graph the default graph
        with self.graph.as_default():
            # Define the input tensor (the input image)
            self.input = tf.placeholder(np.float32,
                                        shape=self.input_size,
                                        name='superpoint/image')
            # Import the model.
            # After import, every original node name gets the prefix import/
            tf.import_graph_def(graph_def, {
                'superpoint/image': self.input,
            })

        # TensorFlow Deep Learning, part 34: tf.Graph.finalize()
        # https://blog.csdn.net/DaVinciL/article/details/84251917
        # Finalize the graph: it becomes read-only and no more nodes can be added
        self.graph.finalize()

        print('Model loading complete!')

        if self.print_log==True:
            # Get layer names
            layers = [op.name for op in self.graph.get_operations()]
            for layer in layers:
                print(layer)

        if self.print_node==True:
            # Check out the weights of the nodes
            weight_nodes = [n for n in graph_def.node if n.op == 'Const']
            for n in weight_nodes:
                print("Name of the node - %s" % n.name)
                print("Value - " )
                print(tensor_util.MakeNdarray(n.attr['value'].tensor))

        # Collect the ops whose feature maps we want to inspect (one per batch-norm layer)
        operations = self.graph.get_operations()
        #self.op_names = ['Placeholder']
        self.op_names = []
        for op in operations:
            if op.name.split('/')[-1] == 'FusedBatchNorm':
                #print(op.name)
                self.op_names.append(op.name)

        # In this version, tf.InteractiveSession and tf.Session could be used interchangeably.
        # self.sess = tf.InteractiveSession(graph = self.graph)
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True  # grow GPU memory usage as needed instead of grabbing it all up front
        self.sess = tf.Session(graph=self.graph, config=config)

    def kps_detect(self, data):

        # Know your output node name
        output_tensor1 = self.graph.get_tensor_by_name("import/superpoint/prob_nms:0")
        output_tensor2 = self.graph.get_tensor_by_name("import/superpoint/descriptors:0")
        output = self.sess.run([output_tensor1, output_tensor2],
                               feed_dict={
                                   self.input: data,
                               })
        return output

    def print_feature_map_shape(self, data):

        for op in self.op_names:
            cur_op_featuremap = self.sess.run(self.graph.get_tensor_by_name(op + ':0'), 
                               feed_dict={
                                   self.input: data,
                               })
            print(op, cur_op_featuremap.shape)

    def get_feature_map(self, data):

        res_featuremap = {}
        for op in self.op_names:
            cur_op_featuremap = self.sess.run(self.graph.get_tensor_by_name(op + ':0'), 
                               feed_dict={
                                   self.input: data,
                               })
            res_featuremap[op] = cur_op_featuremap
        return res_featuremap


model_filepath = 'frozen_graph2.pb'
tf.reset_default_graph()

#model = SuperPoint_TF(model_filepath=model_filepath, input_size=[None, 640, 480, 1])
#one_image = np.zeros((1,640,480,1))
#test_prediction = model.kps_detect(data=one_image)

# The preprocessed input and the outputs obtained from the original ckpt model,
# used to measure the error between this pb model and the ckpt model.
# They are loaded first because the calls below need left_img.
left_img = np.load('输入输出对齐tf/left_img.npy')
tf_out = np.load('输入输出对齐tf/out.npy', allow_pickle=True).item()
left_img = np.expand_dims(left_img, 0)  # add the batch dimension

model = SuperPoint_TF(model_filepath=model_filepath, input_size=[None, 480, 640, 1])
# Print the shape of each feature map
model.print_feature_map_shape(data=left_img)
# Get the feature maps
res_fm = model.get_feature_map(data=left_img)
np.save('输入输出对齐tf/tf_res_fm.npy', res_fm)

test_prediction = model.kps_detect(data=left_img)
prediction0 = np.squeeze(test_prediction[0])
prediction1 = np.squeeze(test_prediction[1])
prediction1 = np.transpose(prediction1, (2, 0, 1))

# Check whether the pb model matches the ckpt outputs
align_prob_nms = prediction0 - tf_out['prob_nms']
align_descriptors = prediction1 - tf_out['descriptors']
print("mean: align_prob_nms ", np.sum(np.abs(align_prob_nms)), np.mean(np.abs(align_prob_nms)))
print("mean: align_descriptors ", np.sum(np.abs(align_descriptors)), np.mean(np.abs(align_descriptors)))


mean: align_prob_nms  0.00013339985 4.342443e-10
mean: align_descriptors  4.786112 6.085856e-08
import/superpoint/pred_tower0/vgg/conv1_1/bn/FusedBatchNorm (1, 480, 640, 64)
import/superpoint/pred_tower0/vgg/conv1_2/bn/FusedBatchNorm (1, 480, 640, 64)
import/superpoint/pred_tower0/vgg/conv2_1/bn/FusedBatchNorm (1, 240, 320, 64)
import/superpoint/pred_tower0/vgg/conv2_2/bn/FusedBatchNorm (1, 240, 320, 64)
import/superpoint/pred_tower0/vgg/conv3_1/bn/FusedBatchNorm (1, 120, 160, 128)
import/superpoint/pred_tower0/vgg/conv3_2/bn/FusedBatchNorm (1, 120, 160, 128)
import/superpoint/pred_tower0/vgg/conv4_1/bn/FusedBatchNorm (1, 60, 80, 128)
import/superpoint/pred_tower0/vgg/conv4_2/bn/FusedBatchNorm (1, 60, 80, 128)
import/superpoint/pred_tower0/detector/conv1/bn/FusedBatchNorm (1, 60, 80, 256)
import/superpoint/pred_tower0/detector/conv2/bn/FusedBatchNorm (1, 60, 80, 65)
import/superpoint/pred_tower0/descriptor/conv1/bn/FusedBatchNorm (1, 60, 80, 256)
import/superpoint/pred_tower0/descriptor/conv2/bn/FusedBatchNorm (1, 60, 80, 256)
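
The shapes in this listing can be sanity-checked with a little arithmetic: the resolution halves three times between the 480 x 640 input and the 60 x 80 head outputs, and the 65 detector channels correspond to the 8 x 8 positions in a grid cell plus one "no keypoint" bin, as in the SuperPoint design. The sketch below assumes grid_size = 8, the value used in the config later in this post; the function name is hypothetical.

```python
def backbone_shapes(height, width, num_halvings=3):
    """Spatial size of the input and after each 2x downsampling stage."""
    shapes = [(height, width)]
    for _ in range(num_halvings):
        height, width = height // 2, width // 2
        shapes.append((height, width))
    return shapes


grid_size = 8  # assumption: matches "grid_size" in the config used later
shapes = backbone_shapes(480, 640)
detector_channels = grid_size * grid_size + 1  # 64 cell positions + 1 "no keypoint" bin

print(shapes)
print(detector_channels)
```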

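4. Inference

Converting models between TensorFlow and PyTorch - 杨杨的博客 - CSDN博客

tfcheckpoint2pytorch/tfcheckpoint2pytorch.py at master · vadimkantorov/tfcheckpoint2pytorch · GitHub

You will run into this error:

ValueError: the ``/`` character is not allowed in object names: 'superpoint/vgg/conv4_2/conv/kernel'

This is actually a limitation of the PyTables library.

Save the weights another way instead (the script below replaces the '/' in each key before saving):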


import tensorflow as tf
import argparse
import os
import numpy as np
from tensorflow.python.pywrap_tensorflow import NewCheckpointReader

def tr(v):
    # tensorflow weights to pytorch weights
    if v.ndim == 4:
        return np.ascontiguousarray(v.transpose(3,2,0,1))
    elif v.ndim == 2:
        return np.ascontiguousarray(v.transpose())
    return v

def read_ckpt(ckpt):
    
    # https://github.com/tensorflow/tensorflow/issues/1823
    reader = tf.train.NewCheckpointReader(ckpt)
    weights = {n: reader.get_tensor(n) for (n, _) in reader.get_variable_to_shape_map().items()}

    #reader = NewCheckpointReader(tf.train.latest_checkpoint(ckpt))
    #weights = {k : reader.get_tensor(k) for k in reader.get_variable_to_shape_map()}
    
    pyweights = {k: tr(v) for (k, v) in weights.items()}
    return pyweights

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Converts ckpt weights to deepdish hdf5")

    parser.add_argument("--infile", type=str, default='tf转换pytorch/model/last.ckpt-600000',
                        help="Path to the ckpt.")

    parser.add_argument("--outfile", type=str, nargs='?', default='tf转换pytorch/model/last.h5',
                        help="Output file (inferred if missing).")

    args = parser.parse_args()

    if args.outfile == '':
        args.outfile = os.path.splitext(args.infile)[0] + '.h5'

    outdir = os.path.dirname(args.outfile)
    # dirname is '' when outfile has no directory part; nothing to create then
    if outdir and not os.path.exists(outdir):
        os.makedirs(outdir)
    weights = read_ckpt(args.infile)

    result = {}
    for m_ind, m_val in enumerate(weights.keys()):
        new_val = m_val.replace("/", "_and_")
        result[new_val] = weights[m_val]

    import deepdish as dd
    dd.io.save(args.outfile, result)
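
The key renaming in the script above can be checked in isolation. This is a minimal sketch with hypothetical helper names; it also refuses names that already contain the placeholder, since the round trip would otherwise be ambiguous.

```python
SEP = "_and_"  # the placeholder used in the conversion script above


def encode_key(tf_name):
    """Make a TF variable name usable as an HDF5/PyTables object name."""
    if SEP in tf_name:
        # Decoding could not tell this apart from an encoded '/', so refuse.
        raise ValueError("name already contains {!r}: {}".format(SEP, tf_name))
    return tf_name.replace("/", SEP)


def decode_key(h5_name, new_sep="/"):
    """Undo encode_key; pass new_sep='.' to get dotted names instead."""
    return h5_name.replace(SEP, new_sep)


name = "superpoint/vgg/conv4_2/conv/kernel"
encoded = encode_key(name)
print(encoded)
print(decode_key(encoded, "."))
```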

Convert the .h5 file to .pth

import torch
import torch.nn as nn
import torch.nn.functional as F
import h5py
import numpy as np
from HSR_tf_superpoint_backbone_pytorch import SuperPointBNNet

import deepdish as dd
pretrained_dict = dd.io.load('tf转换pytorch/model/last.h5')
new_pre_dict = {}
for k,v in pretrained_dict.items():
    new_k = k.replace("_and_", ".")
    if new_k.endswith("Adam") or new_k.endswith("Adam_1"):
        continue
    if new_k.endswith("beta1_power") or new_k.endswith("beta2_power"):
        continue
    if new_k.endswith("global_step"):
        continue   
    new_pre_dict[new_k] = torch.Tensor(v)
    #print("'"+new_k+"': None,")

weight_name_dict = {
    'superpoint.descriptor.conv1.bn.beta': None,
    'superpoint.descriptor.conv1.bn.gamma': None,
    'superpoint.descriptor.conv1.bn.moving_mean': None,
    'superpoint.descriptor.conv1.bn.moving_variance': None,
    'superpoint.descriptor.conv1.conv.bias': None,
    'superpoint.descriptor.conv1.conv.kernel': None,
    'superpoint.descriptor.conv2.bn.beta': None,
    'superpoint.descriptor.conv2.bn.gamma': None,
    'superpoint.descriptor.conv2.bn.moving_mean': None,
    'superpoint.descriptor.conv2.bn.moving_variance': None,
    'superpoint.descriptor.conv2.conv.bias': None,
    'superpoint.descriptor.conv2.conv.kernel': None,
    'superpoint.detector.conv1.bn.beta': None,
    'superpoint.detector.conv1.bn.gamma': None,
    'superpoint.detector.conv1.bn.moving_mean': None,
    'superpoint.detector.conv1.bn.moving_variance': None,
    'superpoint.detector.conv1.conv.bias': None,
    'superpoint.detector.conv1.conv.kernel': None,
    'superpoint.detector.conv2.bn.beta': None,
    'superpoint.detector.conv2.bn.gamma': None,
    'superpoint.detector.conv2.bn.moving_mean': None,
    'superpoint.detector.conv2.bn.moving_variance': None,
    'superpoint.detector.conv2.conv.bias': None,
    'superpoint.detector.conv2.conv.kernel': None,
    'superpoint.vgg.conv1_1.bn.beta': None,
    'superpoint.vgg.conv1_1.bn.gamma': None,
    'superpoint.vgg.conv1_1.bn.moving_mean': None,
    'superpoint.vgg.conv1_1.bn.moving_variance': None,
    'superpoint.vgg.conv1_1.conv.bias': None,
    'superpoint.vgg.conv1_1.conv.kernel': None,
    'superpoint.vgg.conv1_2.bn.beta': None,
    'superpoint.vgg.conv1_2.bn.gamma': None,
    'superpoint.vgg.conv1_2.bn.moving_mean': None,
    'superpoint.vgg.conv1_2.bn.moving_variance': None,
    'superpoint.vgg.conv1_2.conv.bias': None,
    'superpoint.vgg.conv1_2.conv.kernel': None,
    'superpoint.vgg.conv2_1.bn.beta': None,
    'superpoint.vgg.conv2_1.bn.gamma': None,
    'superpoint.vgg.conv2_1.bn.moving_mean': None,
    'superpoint.vgg.conv2_1.bn.moving_variance': None,
    'superpoint.vgg.conv2_1.conv.bias': None,
    'superpoint.vgg.conv2_1.conv.kernel': None,
    'superpoint.vgg.conv2_2.bn.beta': None,
    'superpoint.vgg.conv2_2.bn.gamma': None,
    'superpoint.vgg.conv2_2.bn.moving_mean': None,
    'superpoint.vgg.conv2_2.bn.moving_variance': None,
    'superpoint.vgg.conv2_2.conv.bias': None,
    'superpoint.vgg.conv2_2.conv.kernel': None,
    'superpoint.vgg.conv3_1.bn.beta': None,
    'superpoint.vgg.conv3_1.bn.gamma': None,
    'superpoint.vgg.conv3_1.bn.moving_mean': None,
    'superpoint.vgg.conv3_1.bn.moving_variance': None,
    'superpoint.vgg.conv3_1.conv.bias': None,
    'superpoint.vgg.conv3_1.conv.kernel': None,
    'superpoint.vgg.conv3_2.bn.beta': None,
    'superpoint.vgg.conv3_2.bn.gamma': None,
    'superpoint.vgg.conv3_2.bn.moving_mean': None,
    'superpoint.vgg.conv3_2.bn.moving_variance': None,
    'superpoint.vgg.conv3_2.conv.bias': None,
    'superpoint.vgg.conv3_2.conv.kernel': None,
    'superpoint.vgg.conv4_1.bn.beta': None,
    'superpoint.vgg.conv4_1.bn.gamma': None,
    'superpoint.vgg.conv4_1.bn.moving_mean': None,
    'superpoint.vgg.conv4_1.bn.moving_variance': None,
    'superpoint.vgg.conv4_1.conv.bias': None,
    'superpoint.vgg.conv4_1.conv.kernel': None,
    'superpoint.vgg.conv4_2.bn.beta': None,
    'superpoint.vgg.conv4_2.bn.gamma': None,
    'superpoint.vgg.conv4_2.bn.moving_mean': None,
    'superpoint.vgg.conv4_2.bn.moving_variance': None,
    'superpoint.vgg.conv4_2.conv.bias': None,
    'superpoint.vgg.conv4_2.conv.kernel': None,
}

config = {
    "grid_size": 8,
    # pytorch
    #"det_thresh": 0.015,
    # tensorflow
    "det_thresh": 0.001,
    "nms": 4,
    "topk": -1,
}

net = SuperPointBNNet(config=config, input_channel=1, grid_size=8, device='cpu')
model_dict = net.state_dict()

tf_torch_conv_dict = {
    'weight': 'kernel',
    'bias': 'bias',
}
tf_torch_bn_dict = {
    'running_mean': 'moving_mean',
    'running_var': 'moving_variance',
    'weight': 'gamma',
    'bias': 'beta',
}
result_weight_dict = {}
for m_k,m_v in model_dict.items():
    if m_k.endswith("num_batches_tracked"):
        result_weight_dict[m_k] = m_v
        continue
    #print(m_k)
    if m_k.split('.')[0]=="vgg":
        tf_vgg_conv_name = m_k.split('.')[1]
        tf_subconv = m_k.split('.')[2].split('_')[-1]
        if tf_subconv=="conv":
            tf_w = tf_torch_conv_dict[m_k.split('.')[3]]
        elif tf_subconv=="bn":
            tf_w = tf_torch_bn_dict[m_k.split('.')[3]]
        tf_name = "{}.{}.{}.{}.{}".format("superpoint", "vgg", tf_vgg_conv_name, tf_subconv, tf_w)
    elif m_k.split('.')[0]=="detector_head":
        tf_vgg_conv_name = m_k.split('.')[1]
        tf_subconv = m_k.split('.')[2].split('_')[-1]
        if tf_subconv=="conv":
            tf_w = tf_torch_conv_dict[m_k.split('.')[3]]
        elif tf_subconv=="bn":
            tf_w = tf_torch_bn_dict[m_k.split('.')[3]]
        tf_name = "{}.{}.{}.{}.{}".format("superpoint", "detector", tf_vgg_conv_name, tf_subconv, tf_w)
    elif m_k.split('.')[0]=="descriptor_head":
        tf_vgg_conv_name = m_k.split('.')[1]
        tf_subconv = m_k.split('.')[2].split('_')[-1]
        if tf_subconv=="conv":
            tf_w = tf_torch_conv_dict[m_k.split('.')[3]]
        elif tf_subconv=="bn":
            tf_w = tf_torch_bn_dict[m_k.split('.')[3]]
        tf_name = "{}.{}.{}.{}.{}".format("superpoint", "descriptor", tf_vgg_conv_name, tf_subconv, tf_w)
    
    weight_name_dict[tf_name] = m_k
    result_weight_dict[m_k] = new_pre_dict[tf_name]

# Update the state dict with the converted weights
model_dict.update(result_weight_dict)
# Load it into the network
net.load_state_dict(model_dict)
torch.save(obj=net.state_dict(), f="tf转换pytorch/model/pytorch_net.pth")
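
The three near-identical branches in the loop above all apply one naming rule. As a sketch only, the rule can be written as a single function and tested without loading any weights. It assumes the PyTorch state_dict keys look like vgg.conv1_1.conv1_1_conv.weight, as in this post; the function and table names are hypothetical.

```python
HEAD_TO_TF_SCOPE = {
    "vgg": "vgg",
    "detector_head": "detector",
    "descriptor_head": "descriptor",
}
CONV_PARAMS = {"weight": "kernel", "bias": "bias"}
BN_PARAMS = {
    "weight": "gamma",
    "bias": "beta",
    "running_mean": "moving_mean",
    "running_var": "moving_variance",
}


def torch_key_to_tf_name(torch_key):
    """Map a PyTorch state_dict key to the dotted TF variable name.

    Returns None for keys with no TensorFlow counterpart (num_batches_tracked).
    """
    if torch_key.endswith("num_batches_tracked"):
        return None
    head, block, layer, param = torch_key.split(".")
    kind = layer.split("_")[-1]  # 'conv' or 'bn'
    table = {"conv": CONV_PARAMS, "bn": BN_PARAMS}[kind]
    return ".".join(["superpoint", HEAD_TO_TF_SCOPE[head], block, kind, table[param]])


for key in ("vgg.conv1_1.conv1_1_conv.weight",
            "detector_head.conv2.conv2_bn.running_var",
            "descriptor_head.conv1.conv1_bn.num_batches_tracked"):
    print(key, "->", torch_key_to_tf_name(key))
```

An unexpected key raises a KeyError or ValueError here instead of silently reusing the name from the previous iteration.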

References:

Viewing the shape of feature maps in TensorFlow - runningwei的博客 - CSDN博客

Visualizing feature maps in PyTorch - 温瞳 - CSDN博客

To get the feature maps of intermediate PyTorch layers

Use forward hooks. However, the errors are currently large. The PyTorch model parameters are the ones saved in the previous step:

pytorch_net.pth

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import torch
torch.set_grad_enabled(False)
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
import cv2
from HSR_tf_superpoint_backbone_pytorch import SuperPointBNNet

def extract_superpoint_keypoints_and_descriptors(keypoint_map, descriptor_map,
                                                keep_k_points=1000):
    def select_k_best(points, k):
        """ Select the k most probable points (and strip their proba).
        points has shape (num_points, 3) where the last coordinate is the proba. """
        #top-k (y, x)
        #sorted_prob = points[points[:, 2].argsort(), :2]
        #top-k (y, x, score)
        sorted_prob = points[points[:, 2].argsort(), :]
        start = min(k, points.shape[0])
        return sorted_prob[-start:, :]

    # Extract keypoints
    keypoints = np.where(keypoint_map > 0)
    prob = keypoint_map[keypoints[0], keypoints[1]]
    keypoints = np.stack([keypoints[0], keypoints[1], prob], axis=-1)

    keypoints = select_k_best(keypoints, keep_k_points)
    # np.int and np.float were removed from recent NumPy; the builtins behave the same here
    keypoints_yx = keypoints[:, :2].astype(int)
    keypoints_score = keypoints[:, -1].astype(float)

    # Get descriptors for keypoints
    desc = descriptor_map[keypoints_yx[:, 0], keypoints_yx[:, 1]]

    # Convert (y, x) to (x, y); the scores are returned separately
    return keypoints_yx[:, [1, 0]], keypoints_score, desc


config = {
    "grid_size": 8,
    # pytorch
    #"det_thresh": 0.015,
    # tensorflow
    "det_thresh": 0.001,
    "nms": 4,
    # pytorch
    # "topk": -1,
    # tensorflow
    "topk": 0,
}

sp_det = SuperPointBNNet(config=config, input_channel=1, grid_size=8, device='cpu')
sp_det.load_state_dict(torch.load("/home/ninghua/深度学习/特征点检测/tf转换pytorch/superpoint_hsr_H320_16w/model/pytorch_net.pth"))
#data_transform = transforms.Compose([transforms.ToTensor()])

left_img = np.load('/home/ninghua/深度学习/特征点检测/tf转换pytorch/superpoint_hsr_H320_16w/model/输入输出对齐tf/left_img.npy')
left_orig = np.load('/home/ninghua/深度学习/特征点检测/tf转换pytorch/superpoint_hsr_H320_16w/model/输入输出对齐tf/left_orig.npy')
tf_out = np.load('/home/ninghua/深度学习/特征点检测/tf转换pytorch/superpoint_hsr_H320_16w/model/输入输出对齐tf/out.npy', allow_pickle=True).item()

left_img = left_img.transpose((2,0,1))
image = torch.from_numpy(left_img)
image = image.unsqueeze(0).cuda()
sp_det.cuda()
sp_det.eval()
# # GPU memory usage: 1766 MiB
# with torch.no_grad():
#     torch_out = sp_det(image)

pytorch_mid_res = []
def get_image_name_func(module):
    module_str = str(module)
    image_name = module_str.rsplit(':',1)[0].split('(')[-1].rsplit('_',1)[0]
    return image_name

def hook_func(module, input, output):
    '''
    Forward hook: store a copy of the module's output.
    param module: the module the hook is registered on
    param input: the inputs to the module
    param output: the output of the module
    '''
    data = output.clone().detach()
    # (n,c,h,w) -> (n,h,w,c), to match the TensorFlow layout
    data = data.permute(0, 2, 3, 1)
    pytorch_mid_res.append(data)
    
# Layers whose feature maps are captured
exact_list = [
    'vgg.conv1_1.conv1_1_bn',
    'vgg.conv1_2.conv1_2_bn',
    'vgg.conv2_1.conv2_1_bn',
    'vgg.conv2_2.conv2_2_bn',
    'vgg.conv3_1.conv3_1_bn',
    'vgg.conv3_2.conv3_2_bn',
    'vgg.conv4_1.conv4_1_bn',
    'vgg.conv4_2.conv4_2_bn',
    'detector_head.conv1.conv1_bn',
    'detector_head.conv2.conv2_bn',
    'descriptor_head.conv1.conv1_bn',
    'descriptor_head.conv2.conv2_bn',
]
for name, module in sp_det.named_modules():
    # print(name)
    if name in exact_list:
        module.register_forward_hook(hook_func)
# GPU memory usage: 1766 MiB
with torch.no_grad():
    torch_out = sp_det(image)

# Check whether the feature maps match
tf_res_fm = np.load('/home/ninghua/深度学习/特征点检测/tf转换pytorch/superpoint_hsr_H320_16w/model/输入输出对齐tf/tf_res_fm.npy', allow_pickle=True).item()
count = 0
for m_k, m_v in tf_res_fm.items():
    print(m_k, " mean: ")
    tf_FM = m_v
    torch_FM = pytorch_mid_res[count].cpu().numpy()
    print(tf_FM.shape, " ", torch_FM.shape)
    res_mean = np.abs(tf_FM - torch_FM)
    print("mean: {} ".format(m_k), np.sum(res_mean), np.mean(res_mean))
    count += 1

tf_prob_nms = tf_out['prob_nms']
tf_descriptors = tf_out['descriptors']
tf_keypoint_map = np.squeeze(tf_prob_nms)
tf_descriptor_map = np.squeeze(tf_descriptors)
tf_descriptor_map = np.transpose(tf_descriptor_map, (1, 2, 0))
tf_kp1, tf_kpscore, tf_desc1 = extract_superpoint_keypoints_and_descriptors(tf_keypoint_map, tf_descriptor_map)

torch_prob_nms = torch_out['det_dict']['prob_nms']
torch_descriptors = torch_out['desc_dict']['desc']
torch_prob_nms = torch_prob_nms.cpu().numpy()
torch_descriptors = torch_descriptors.cpu().numpy()
torch_keypoint_map = np.squeeze(torch_prob_nms)
torch_descriptor_map = np.squeeze(torch_descriptors)
torch_descriptor_map = np.transpose(torch_descriptor_map, (1, 2, 0))
torch_kp1, torch_kpscore, torch_desc1 = extract_superpoint_keypoints_and_descriptors(torch_keypoint_map, torch_descriptor_map)

res_mean_out1 = np.abs(tf_keypoint_map - torch_keypoint_map)
print("res_mean_out1:  ", np.sum(res_mean_out1), np.mean(res_mean_out1))
res_mean_out2 = np.abs(tf_descriptor_map - torch_descriptor_map)
print("res_mean_out2:  ", np.sum(res_mean_out2), np.mean(res_mean_out2))

np.testing.assert_allclose(tf_descriptor_map,
                           torch_descriptor_map,
                           rtol=0.01, atol=0)

np.testing.assert_allclose(tf_keypoint_map,
                           torch_keypoint_map,
                           rtol=0.01, atol=0)

vis_tf = left_orig.copy()
vis_torch = left_orig.copy()
h1, w1 = vis_tf.shape[:2]
h2, w2 = vis_torch.shape[:2]
vis = np.zeros((max(h1, h2), w1 + w2, 3), np.uint8)

# Draw the keypoints
red = (0, 0, 255)  # OpenCV uses BGR order, so this is red
for (x1, y1) in tf_kp1:
    cv2.circle(vis_tf, (int(x1), int(y1)), 2, red, -1)

for (x2, y2) in torch_kp1:
    cv2.circle(vis_torch, (int(x2), int(y2)), 2, red, -1)

vis[:h1, :w1, :] = vis_tf
vis[:h2, w1:w1 + w2, :] = vis_torch
cv2.imwrite("vis_result.png", vis)

vgg
vgg.conv1_1
vgg.conv1_1.conv1_1_conv
vgg.conv1_1.conv1_1_rlue
vgg.conv1_1.conv1_1_bn
vgg.conv1_2
vgg.conv1_2.conv1_2_conv
vgg.conv1_2.conv1_2_rlue
vgg.conv1_2.conv1_2_bn
vgg.pool1
vgg.conv2_1
vgg.conv2_1.conv2_1_conv
vgg.conv2_1.conv2_1_rlue
vgg.conv2_1.conv2_1_bn
vgg.conv2_2
vgg.conv2_2.conv2_2_conv
vgg.conv2_2.conv2_2_rlue
vgg.conv2_2.conv2_2_bn
vgg.pool2
vgg.conv3_1
vgg.conv3_1.conv3_1_conv
vgg.conv3_1.conv3_1_rlue
vgg.conv3_1.conv3_1_bn
vgg.conv3_2
vgg.conv3_2.conv3_2_conv
vgg.conv3_2.conv3_2_rlue
vgg.conv3_2.conv3_2_bn
vgg.pool3
vgg.conv4_1
vgg.conv4_1.conv4_1_conv
vgg.conv4_1.conv4_1_rlue
vgg.conv4_1.conv4_1_bn
vgg.conv4_2
vgg.conv4_2.conv4_2_conv
vgg.conv4_2.conv4_2_rlue
vgg.conv4_2.conv4_2_bn
detector_head
detector_head.conv1
detector_head.conv1.conv1_conv
detector_head.conv1.conv1_rlue
detector_head.conv1.conv1_bn
detector_head.conv2
detector_head.conv2.conv2_conv
detector_head.conv2.conv2_bn
detector_head.softmax
descriptor_head
descriptor_head.conv1
descriptor_head.conv1.conv1_conv
descriptor_head.conv1.conv1_rlue
descriptor_head.conv1.conv1_bn
descriptor_head.conv2
descriptor_head.conv2.conv2_conv
descriptor_head.conv2.conv2_bn
import/superpoint/pred_tower0/vgg/conv1_1/bn/FusedBatchNorm  mean: 
(1, 480, 640, 64)   (1, 480, 640, 64)
mean: import/superpoint/pred_tower0/vgg/conv1_1/bn/FusedBatchNorm  4480126.5 0.22787102
import/superpoint/pred_tower0/vgg/conv1_2/bn/FusedBatchNorm  mean: 
(1, 480, 640, 64)   (1, 480, 640, 64)
mean: import/superpoint/pred_tower0/vgg/conv1_2/bn/FusedBatchNorm  22140818.0 1.1261402
import/superpoint/pred_tower0/vgg/conv2_1/bn/FusedBatchNorm  mean: 
(1, 240, 320, 64)   (1, 240, 320, 64)
mean: import/superpoint/pred_tower0/vgg/conv2_1/bn/FusedBatchNorm  2442779.0 0.49698466
import/superpoint/pred_tower0/vgg/conv2_2/bn/FusedBatchNorm  mean: 
(1, 240, 320, 64)   (1, 240, 320, 64)
mean: import/superpoint/pred_tower0/vgg/conv2_2/bn/FusedBatchNorm  2027317.8 0.41245887
import/superpoint/pred_tower0/vgg/conv3_1/bn/FusedBatchNorm  mean: 
(1, 120, 160, 128)   (1, 120, 160, 128)
mean: import/superpoint/pred_tower0/vgg/conv3_1/bn/FusedBatchNorm  941567.6 0.38312486
import/superpoint/pred_tower0/vgg/conv3_2/bn/FusedBatchNorm  mean: 
(1, 120, 160, 128)   (1, 120, 160, 128)
mean: import/superpoint/pred_tower0/vgg/conv3_2/bn/FusedBatchNorm  867199.75 0.35286447
import/superpoint/pred_tower0/vgg/conv4_1/bn/FusedBatchNorm  mean: 
(1, 60, 80, 128)   (1, 60, 80, 128)
mean: import/superpoint/pred_tower0/vgg/conv4_1/bn/FusedBatchNorm  132582.7 0.21579216
import/superpoint/pred_tower0/vgg/conv4_2/bn/FusedBatchNorm  mean: 
(1, 60, 80, 128)   (1, 60, 80, 128)
mean: import/superpoint/pred_tower0/vgg/conv4_2/bn/FusedBatchNorm  131524.86 0.21407041
import/superpoint/pred_tower0/detector/conv1/bn/FusedBatchNorm  mean: 
(1, 60, 80, 256)   (1, 60, 80, 256)
mean: import/superpoint/pred_tower0/detector/conv1/bn/FusedBatchNorm  283121.47 0.23040484
import/superpoint/pred_tower0/detector/conv2/bn/FusedBatchNorm  mean: 
(1, 60, 80, 65)   (1, 60, 80, 65)
mean: import/superpoint/pred_tower0/detector/conv2/bn/FusedBatchNorm  333244.3 1.0680908
import/superpoint/pred_tower0/descriptor/conv1/bn/FusedBatchNorm  mean: 
(1, 60, 80, 256)   (1, 60, 80, 256)
mean: import/superpoint/pred_tower0/descriptor/conv1/bn/FusedBatchNorm  179328.12 0.1459376
import/superpoint/pred_tower0/descriptor/conv2/bn/FusedBatchNorm  mean: 
(1, 60, 80, 256)   (1, 60, 80, 256)
mean: import/superpoint/pred_tower0/descriptor/conv2/bn/FusedBatchNorm  88100.3 0.07169621
res_mean_out1:   96.13022 0.0003129239
res_mean_out2:   2670920.5 0.03396251
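
The post does not identify the cause of these differences. One thing that may be worth ruling out, offered here as an assumption and not a confirmed diagnosis, is the batch-norm epsilon: the default for tf.layers.batch_normalization is 0.001, while torch.nn.BatchNorm2d defaults to 1e-5. Whether the models in this post rely on those defaults is not shown. The sketch below uses made-up statistics to show how much the inference-mode output of a single activation can shift when the moving variance is small.

```python
import math

TF_LAYERS_BN_EPS = 1e-3  # default epsilon of tf.layers.batch_normalization
TORCH_BN_EPS = 1e-5      # default eps of torch.nn.BatchNorm2d


def batch_norm_inference(x, mean, var, gamma, beta, eps):
    """Inference-mode batch norm for one scalar activation."""
    return gamma * (x - mean) / math.sqrt(var + eps) + beta


# Made-up statistics for a channel with a small moving variance.
x, mean, var, gamma, beta = 0.30, 0.10, 0.002, 1.0, 0.0

y_tf = batch_norm_inference(x, mean, var, gamma, beta, TF_LAYERS_BN_EPS)
y_torch = batch_norm_inference(x, mean, var, gamma, beta, TORCH_BN_EPS)

print(y_tf, y_torch, abs(y_tf - y_torch))
```

If this turns out to apply, constructing the PyTorch batch-norm layers with eps=1e-3 would be the corresponding change; comparing the first conv output before any batch norm is another way to narrow down where the divergence starts.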
