tensor转list_tensorflow中ckpt转pb

1db41a00447c32c29a5f370d9b2a7135.png

ckpt转pb

ckpt转pb有两种方式,一种是通过.meta文件加载图和ckpt文件固化成ckpt,一种是加载运行一次网络固化成pb,下面分别介绍

加载`meta`文件固化

加载.meta文件恢复网络graph,然后加载ckpt,将变量转换为constant常量,再移除ckpt中保留的训练相关但和前向推理无关的结点。

import tensorflow as tf
from tensorflow.python.framework import graph_util

def ckpt2pb():
    with tf.Graph().as_default() as graph_old:
        isess = tf.InteractiveSession()

        ckpt_filename = './model.ckpt'
        isess.run(tf.global_variables_initializer())
        saver = tf.train.import_meta_graph(ckpt_filename+'.meta', clear_devices=True)
        saver.restore(isess, ckpt_filename)

        constant_graph = graph_util.convert_variables_to_constants(isess, isess.graph_def, ["Cls/fc/biases"])
        constant_graph = graph_util.remove_training_nodes(constant_graph)

        with tf.gfile.GFile('./pb/model.pb', mode='wb') as f:
            f.write(constant_graph.SerializeToString())

运行一次网络固化

运行网络加载graph的优点是可以在网络起止位置手动添加结点标记,方便定义网络起止结点的名字。

def inference(is_training, img_input):
    '''your network
    '''

def ckpt2pb2():
    with tf.Graph().as_default() as graph_old:
        img_input = tf.placeholder(tf.float32, shape=(None, 40, 120, 3))
        model = inference(False, img_input)

        isess = tf.InteractiveSession()
        ckpt_filename ='./model.ckpt'

        isess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(isess, ckpt_filename)

        constant_graph = graph_util.convert_variables_to_constants(isess, isess.graph_def, ["Cls/fc/BiasAdd"])
        constant_graph = graph_util.remove_training_nodes(constant_graph)

        with tf.gfile.GFile('./pb/model.pb', mode='wb') as f:
            f.write(constant_graph.SerializeToString())

使用pb做前向推理

import tensorflow as tf
import cv2
import numpy as np

def inference_use_pb():
    graph_path = './pb/module.pb'
    graph_def = tf.GraphDef()
    with tf.gfile.FastGFile(graph_path,'rb') as f:
        graph_def.ParseFromString(f.read())
        _ = importer.import_graph_def(graph_def, name="")

    isess = tf.InteractiveSession()
    images_placeholder = tf.get_default_graph().get_tensor_by_name("placeholder:0")

    embeddings = tf.get_default_graph().get_tensor_by_name("Cls/fc/BiasAdd:0")

    img = cv2.imread('./1.png')
    image = cv2.resize(img, (120, 40)).astype(np.float32) / 255.0
    image = np.reshape(image, (40,120,3))

    res = isess.run([embeddings], feed_dict={images_placeholder: np.reshape(image, [1, 40, 120, 3])})
    print(res)

读取ckpt文件

有时我们需要从ckpt文件中获取权重,则可以使用以下方法

import tensorflow as tf

ckpt_path = './model.ckpt'

with tf.Session() as sess:
    for var_name, _ in tf.contrib.framework.list_variables(ckpt_path):
        print(var_name)
        var = tf.contrib.framework.load_variable(ckpt_path, var_name)
        print(var.shape)
        print(var)

读取pb文件

当没有ckpt文件,只有pb文件时,使用以下方法获取权重

from tensorflow.python.platform import gfile
from tensorflow.python.framework import tensor_util

graph_path = './model.pb'

def values_from_const(node_def):
    if node_def.op != "Const":
        raise ValueError("Node named '%s' should be a Const op for values_from_const." % node_def.name)
    input_tensor = node_def.attr["value"].tensor
    tensor_value = tensor_util.MakeNdarray(input_tensor)
    return tensor_value

def read_pb():
    input_graph_def = graph_pb2.GraphDef()
    with gfile.Open(graph_path, "rb") as f:
        data = f.read()
        input_graph_def.ParseFromString(data)

    for node in input_graph_def.node:
        print(node.name)
        print(node.op)
        if node.op == "Const":
            if 'weights' in node.name:
                weight = values_from_const(node).reshape(-1)
                print(weight)

if __name__ == "__main__":
    read_pb()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值