Tensorflow基于pb模型进行预训练(pb模型转CKPT模型)

Tensorflow基于pb模型进行预训练(pb模型转CKPT模型)

在网上看到很多教程都是tensorflow基于pb模型进行推理,而不是进行预训练。最近在在做项目的过程中发现之前的大哥只有一个pb模型留给我。。。。从头训练的时间又太长,因此还是决定来将pb模型中的参数析出来变成能进行预训练的ckpt模型(主要是不想去改原本的训练代码)

1.首先要加载pb文件

def load_model(model_file, input_map=None):
    with tf.gfile.FastGFile(model_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, input_map=input_map, name='')
       

2.获取所有的操作节点

一般我们加载了图之后,都是去获得他的占位符去进行输入,然后输出.为了得到所有的权重,我们使用g.get_operations()获得所有的操作节点.

import tensorflow as tf

if __name__ == "__main__":
    g = tf.get_default_graph()
    sess = tf.Session()
    load_model('pretrained/mobilenetv1_1.0.pb')
    optlist = g.get_operations()
    optlist = list(optlist)

3.获得tensor

optlist就是操作列表,操作列表中的数据如下所示:

In [11]: optlist
Out[11]:
[<tf.Operation 'inputs' type=Placeholder>,
 <tf.Operation 'MobileNetV1/SpaceToBatchND/block_shape' type=Const>,
 <tf.Operation 'MobileNetV1/SpaceToBatchND/paddings' type=Const>,
 <tf.Operation 'MobileNetV1/SpaceToBatchND' type=SpaceToBatchND>,
 <tf.Operation 'MobileNetV1/Conv2d_0_3x3/weights' type=Const>,
 <tf.Operation 'MobileNetV1/Conv2d_0_3x3/weights/read' type=Identity>,
 <tf.Operation 'MobileNetV1/Conv2d_0_3x3/Conv2D' type=Conv2D>,
 <tf.Operation 'MobileNetV1/Conv2d_0_3x3/BatchNorm/beta' type=Const>,
 <tf.Operation 'MobileNetV1/Conv2d_0_3x3/BatchNorm/beta/read' type=Identity>

此时,我们需要读取预存的权重,通过资料直接说明xxxx/read等操作就是读取预存的权重的操作.因此我们可以直接把这些操作过滤出来

def get_vars_from_optlist(optlist: list)->list:
    """ 从optlist获得所有的变量节点 """
    varlist = [node for node in optlist if '/read' in node.name]
    return varlist

现在我们有个对应的读取变量操作列表,但是要读取变量还是要进行转化,因为varlist只是一个操作,还没有变成可运行的tensor,所以我只要在操作名后面加上:0,同时get_tensor_by_name()即可得到对应的tensor

def convert_vars_to_tensor(g, varlist: list)->list:
    """ 把varlist中的操作转变为可运行的tensor """
    tensorlist = []
    for var in varlist:
        tensorlist.append(g.get_tensor_by_name(var.name+':0'))
    return tensorlist

4.获取变量

有了tensorlist,我们可以来读取变量了.为了restore的方便,我将他保存成字典的形式,并且修改每一个key都与原图中的变量名相同,这样restore的时候直接判断名字是否相同即可.

# 将所有变量存入字典
vardict = {}
for v in tensorlist:
    vardict[v.name.replace('/read', '')] = sess.run(v)

5.保存字典

使用下面的这个函数保存我们的vardict

def save_pkl(obj, name):
    with open(name, 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

这个函数中obj就是我们要保存的vardict,name是要保存的文件路径

由此,我们可以得到一个存着网络层和对应权重的pickle文件
整套代码如下所示:

import tensorflow as tf
import pickle

def load_model(model_file, input_map=None):
    with tf.gfile.FastGFile(model_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, input_map=input_map, name='')

def get_vars_from_optlist(optlist: list)->list:
    """ 从optlist获得所有的变量节点 """
    varlist = [node for node in optlist if '/read' in node.name]
    return varlist


def convert_vars_to_tensor(g, varlist: list)->list:
    """ 把varlist中的操作转变为可运行的tensor """
    tensorlist = []
    for var in varlist:
        tensorlist.append(g.get_tensor_by_name(var.name+':0'))
    return tensorlist

def save_pkl(obj, name):
    with open(name, 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

if __name__ == "__main__":
    model_file = './frozen_model_small.pb'
    g = tf.get_default_graph()
    sess = tf.Session()
    load_model(model_file)
    optlist = g.get_operations()
    optlist = list(optlist)
    varlist = get_vars_from_optlist(optlist)
    tensorlist = convert_vars_to_tensor(g, varlist)
    # 将所有变量存入字典
    vardict = {}
    for v in tensorlist:
        vardict[v.name.replace('/read', '')] = sess.run(v)
    
    file_name = './frozen_model_small_pb.pickle'
    save_pkl(vardict, file_name)

6.转ckpt模型

接下来我做了一个骚操作- -我首先从头训练了这个模型,保存了一个能进行预训练的ckpt模型,接下来就是将之前得到的pickle文件里的weight值一一替换这个ckpt里的weight值,形成一个新的ckpt模型。(这一步要将模型本身的网络引入,即必须要有原图的定义)

import tensorflow as tf
import pickle
# 自己的网络模型
from infer.networks import model



if __name__ == "__main__":
    pklpath = "./frozen_model_small_pb.pickle"

    with open(pklpath, 'rb') as f:
        pre_weight_dict = pickle.load(f)
    
    config = tf.ConfigProto(allow_soft_placement=True)
    graph = tf.Graph()
    with graph.as_default():
        with tf.Session(config=config) as sess:
            gpu_options = tf.GPUOptions(allow_growth=True)
            sess = tf.Session(
                    config=tf.ConfigProto(
                        allow_soft_placement=True, gpu_options=gpu_options))


            input_image = tf.placeholder(tf.float32, [None, 448, 448, 3], name="input_img")
            # postprocessed_dict = tf.placeholder(tf.float32, [None], name="output")
            postprocessed_dict = model.modelv2(input_image)
            saver  = tf.train.Saver()
            saver.restore(sess,"./model_direction.ckpt-1")
    
            opt_list = []
            for k, oldv in pre_weight_dict.items():
                fq = graph.get_tensor_by_name(k)
                opt_list.append(tf.assign(fq, oldv))
            
            sess.run(opt_list)
            saver.save(sess, './pb_to_ckpt_model.ckpt',global_step=1)

参考资料

https://zhen8838.github.io/2019/01/28/pb-to-pkl/

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值