【模型转换】caffe模型转换为tensorflow的pb文件

地址: https://github.com/ethereon/caffe-tensorflow 

转换需要提供caffe的网络模型文件deploy.pt和参数数据文件.caffemodel。

比如 caffe 定义的cnn网络 ResNet-101的转换命令行:

./convert.py ResNet-101-deploy.pt --caffemodel ResNet-101-model.caffemodel --code-output-path=ResNetTF.py --data-output-path=ResNetTF.npy 
 

将.npy文件转换为 .pb 模型文件的方法通过以下代码实现:

import tensorflow as tf
import detect_face
import os
from tensorflow.python.framework.graph_util import convert_variables_to_constants
 
 
# tf.InteractiveSession()让自己成为默认的session,用户不需指明哪个session运行情况下就可运行起来
# tf.InteractiveSession()来构建会话时,可以先构建一个session然后再定义操作
# 使用tf.Session()来构建会话,需要在会话构建之前定义好全部的操作(operation)然后再构建会话
sess = tf.InteractiveSession()
 
with tf.variable_scope("pnet"):
    data = tf.placeholder(tf.float32, (None, None, None, 3), "input")
    pnet = detect_face.PNet({"data": data})
    pnet.load("det1.npy", sess)
 
with tf.variable_scope("rnet"):
    data = tf.placeholder(tf.float32, (None, 24, 24, 3), "input")
    rnet = detect_face.RNet({"data": data})
    rnet.load("det2.npy", sess)
 
with tf.variable_scope("onet"):
    data = tf.placeholder(tf.float32, (None, 48, 48, 3), "input")
    onet = detect_face.ONet({"data": data})
    onet.load("det3.npy", sess)
 
# 将存储到 .npy文件中的网络模型参数转换成用 .bp文件存储的模型格式
"""
tf模型导出为单个文件(同时包含模型架构定义与权重)
利用tf.train_write_graph()默认情况下只导出了网络的定义(无权重)
利用tf.train_Saver().Save()导出文件graph_def与权重分离的,graph_def没有包含网络中的Variable值
(通常情况只存储了权重),但却包含了constant值,如果把Variable转换成constant,
可达到使用一个文件同时存储网络架构与权重的目标
convert_variables_to_constants函数会将计算图中的变量取值以常量的形式保存,
在保存模型文件的时候只是导出了GraphDef部分,GraphDef保存了从输入到输出的计算过程
保存的时候通过convert_variables_to_constants函数来指定保存的节点名称而不是张量的名称
比如:“add:0”是张量的名称,而"add"表示的是节点的名称。
"""
#
constant_graph = convert_variables_to_constants(sess, sess.graph_def,
                                                ["pnet/input", "rnet/input", "onet/input",
                                                 "pnet/conv4-2/BiasAdd", "pnet/prob1",
                                                 "rnet/conv5-2/conv5-2", "rnet/prob1",
                                                 "onet/conv6-2/conv6-2", "onet/conv6-3/conv6-3",
                                                 "onet/prob1"])
 
 
with tf.gfile.FastGFile("face_detect.pb", mode="wb") as f:
    f.write(constant_graph.SerializeToString())

参考:https://blog.csdn.net/zchang81/article/details/76229017

          https://blog.csdn.net/loveliuzz/article/details/81363272

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值