keras的训练权值转成tensorflow的权值pb文件

  因为目前keras比较容易上手,很多网络现在都是用keras做训练,得到的权值结果保存在h5文件中。如果想要在工程部署上使用离线调用,一般都是转成pb文件,尤其在C++调用时,tensorflow有专门为C++调用的API,非常好用。

1. 此文生成的keras权值是基于keras2.1.0,python3.6.5。

2. 保存权值时使用的代码是:model.save()函数。

3. 定义h5文件转pb文件的函数:在python下完成

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    graph = session.graph
    with graph.as_default():
        input_graph_def = graph.as_graph_def()
        frozen_graph = graph_util.convert_variables_to_constants(session, input_graph_def, output_names)  
        return frozen_graph

4. 调用freeze_session()函数

from keras.models import load_model
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io


input_fld = './h5_weight/' 
weight_file = 'k_g_27_image_0_c1.h5'

output_graph_name = 'k_g_27_image_0_c1.pb'
output_fld = './pb_weight/'

if not os.path.isdir(output_fld):
    os.mkdir(output_fld)

weight_file_path = osp.join(input_fld, weight_file)
K.set_learning_phase(0)
net_model = load_model(weight_file_path)  
sess = K.get_session()

frozen_graph = freeze_session(sess, output_names=[net_model.output.op.name])
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)

在测试过程中 ,函数write_graph的参数as_text的值设置成true时,在后面读取pb文件时会出错。

5. 在python中调用转换的pb文件进行预测

import tensorflow as tf
from PIL import Image
import numpy as np
import os


def freeze_graph_test(pb_path, image_path, out_path):
    '''
    :param pb_path:pb文件的路径
    :param image_data:测试图片的路径
    :return:
    '''
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name="")
            tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
            for tensor_name in tensor_name_list:
                print(tensor_name, '\n')


        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            input_image_tensor = sess.graph.get_tensor_by_name("input_1:0")  # 定义输入的张量名称,对应网络结构的输入张量
            output_tensor_name = sess.graph.get_tensor_by_name("conv2d_19/Sigmoid:0")  # 定义输出的张量名称。
            

            im = read_image(image_path)  # 读取测试图片, 按照训练输入的数据格式准备
            im=np.expand_dims(im, axis=3)

            out = sess.run(output_tensor_name, feed_dict={input_image_tensor: im}) 




image_path = "./data_predict_c1/test/images/7.jpg"
pb_path = "./pb_weight/k_g_27_image_0_c1.pb"
out_path = "./data_predict_c1/predict/"
if not os.path.exists(out_path):
    os.makedirs(out_path)

freeze_graph_test(pb_path=pb_path, image_path=image_path, out_path=out_path)

 

以上就是keras的h5权值文件转tensorflow的pb文件的整个过程。

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值