Tensorflow中checkPoint到tflite模型的转换

一、ckpt模型转换为frozen_graph.pb模型

import tensorflow as tf

def freeze_graph(input_checkpoint, output_graph):
    output_node_names = "output" #获取的节点
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph()  # 获得默认的图
    input_graph_def = graph.as_graph_def()  # 返回一个序列化的图代表当前的图

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # 恢复图并得到数据
        output_graph_def = tf.graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
            sess=sess,
            input_graph_def=input_graph_def,  # 等于:sess.graph_def
            output_node_names=output_node_names.split(","))  # 如果有多个输出节点,以逗号隔开

        with tf.gfile.GFile(output_graph, "wb") as f:  # 保存模型
            f.write(output_graph_def.SerializeToString())  # 序列化输出
        # print("%d ops in the final graph." % len(output_graph_def.node))  # 得到当前图有几个操作节点

if __name__ == '__main__':
    modelpath="./checkPointModel/model.ckpt"
    freeze_graph(modelpath,"frozen.pb")
    print("finish!")

二、frozen_graph.pb模型转换为tflite模型

import tensorflow as tf

convert=tf.lite.TFLiteConverter.from_frozen_graph("frozen.pb",input_arrays=["train_x"],output_arrays=["output"])
convert.post_training_quantize=True
tflite_model=convert.convert()
open("model.tflite","wb").write(tflite_model)

当需要给定输入数据形式时,给出输入格式:
import tensorflow as tf
path="./fullLayer/"
convert=tf.lite.TFLiteConverter.from_frozen_graph(path+"frozen.pb",input_arrays=["images"],output_arrays=["output"],
                                                  input_shapes={"images":[1,540,960,1]})
convert.post_training_quantize=True
tflite_model=convert.convert()
open(path+"quantized_model.tflite","wb").write(tflite_model)
print("finish!")

三、调用PB文件模型

 

import tensorflow as tf
import cv2 as cv
import numpy as np

if __name__=="__main__":
    test_pb_model=True
    test_tflite_model=False
    read_cahnge_graph=False

    pb_model_path="./fullLayer/frozen.pb"
    tflite_model_path = "./fullLayer/quantized_model.tflite"  # layer2   fullLayer
    input_node_name="iamges"
    output_node_name="output"

    src_img=cv.imread("1.jpg")
    cv.imwrite("src.jpg",src_img)

    src_img=cv.resize(src_img,(960,540))
    src_img=cv.cvtColor(src_img,cv.COLOR_BGR2YCrCb)[:,:,0]
    src_img=src_img/127.5-1
    src_img=src_img.astype("float32")

    src_img=src_img.reshape((1,540,960,1))

    if test_tflite_model:
        interpreter=tf.lite.Interpreter(tflite_model_path)
        interpreter.allocate_tensors()

        input_details=interpreter.get_input_details()
        # print(str(input_details))
        output_details=interpreter.get_output_details()

        interpreter.set_tensor(input_details[0]["index"],src_img)

        interpreter.invoke()
        output_data=interpreter.get_tensor(output_details[0]["index"])

        result=output_data[0]

        result=(result+1)*127.5
        result[result>255]=255
        result[result<0]=0
        result=result.astype(np.uint8)
        cv.imshow("result",result)
        cv.imwrite("result.jpg", result)
        cv.waitKey()
    if test_pb_model:
        src_img=cv.imread("1.jpg")
        src_img=cv.resize(src_img,(960,540))
        src_img=cv.cvtColor(src_img,cv.COLOR_BGR2YCrCb)[:,:,0]
        src_img=src_img/127.5-1
        src_img=src_img.astype("float32")
        src_img=src_img.reshape((1,540,960,1))
        input_image=tf.placeholder(tf.float32,(1,540,960,1))

        with open(pb_model_path,"rb") as f:
            graph_def=tf.GraphDef()
            graph_def.ParseFromString(f.read())
            out_result=tf.import_graph_def(graph_def,input_map={"images:0":input_image},return_elements=["output:0"])
        sess=tf.Session()
        result=sess.run(out_result,feed_dict={input_image:src_img})

        result=result[0][0]
        result=(result+1)*127.5
        result[result>255]=255
        result[result<0]=0
        result=result.astype(np.uint8)

        cv.imshow("resut",result)
        cv.waitKey()
    if read_cahnge_graph:
        gf=tf.GraphDef()
        gf.ParseFromString(open(pb_model_path,"rb").read())
        for n in gf.node:
            print(n.name + " ===> "+n.op )

四、调用tflite模型

import tensorflow as tf
import cv2 as cv
import numpy as np

from tensorflow.examples.tutorials.mnist import input_data
minst=input_data.read_data_sets("./mnist/",one_hot=True)
index=100
one_minist=minst.train.images[index]
one_minist_img=one_minist.reshape((1,28,28,1))
print("image real value:{}".format(np.argmax(minst.train.labels[index],0)))

test_image_dir = 'test.png'
#model_path = "./model/quantize_frozen_graph.tflite"
model_path = "model.tflite"

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
# print(str(input_details))
output_details = interpreter.get_output_details()
# print(str(output_details))

src=cv.imread(test_image_dir,cv.IMREAD_GRAYSCALE)
src=cv.resize(src,(28,28))
# cv.imshow("gray_img",one_minist_img.reshape([28,28]))
# cv.waitKey()


# 增加一个维度,变为 [1, 784]
src = np.expand_dims(src, axis=0)
src = np.expand_dims(src, axis=3)
# print(src.shape)
src = src.astype('float32')  # 类型也要满足要求


# src=one_minist_img  #测试minist数据集


# 填装数据
interpreter.set_tensor(input_details[0]['index'], src)


interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])


# 出来的结果去掉没用的维度
result = np.squeeze(output_data)
# print('result:{}'.format(result))

# 输出结果是长度为10(对应0-9)的一维数据,最大值的下标就是预测的数字
print('result:{}'.format((np.where(result == np.max(result)))[0][0]))

 

&*&一些坑

1.tensorflow1.12.0版本在pb模型转换为tflite时会出现‘No module named ‘_tensorflow_wrap_toco’’,搜索了下竟然是官方的问题。升级为tf-nightly1.13问题解决了。  在安装f-nightly1.13前先卸载原有的tensorflow版本,安装后可能会遇到numpy.core...无法import情况,先卸载numpy,删除其残留文件,再安装numpy。

  • 5
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值