tensorflow模型处理(查看节点、模型转换、模型测试)

本章提供了 TensorFlow 模型常用的一些处理方法(示例代码基于 TensorFlow 1.x 的 API,如 tf.Session、tf.GraphDef),包括:

tensorboard查看ckpt网络结构

tensorboard查看pb网络结构

ckpt模型转pb模型

pb模型转pbtxt文件

测试pb模型

pb模型转tflite模型

测试tflite模型

h5模型转pb模型

测试caffe模型

1.1 查看ckpt网络结构_v1.py

先运行下面的脚本,将网络结构写入日志目录 d:/log,然后:

在控制台输入命令:tensorboard --logdir=d:/log  --host=127.0.0.1

浏览器输入:http://127.0.0.1:6006/

from tensorflow.python import pywrap_tensorflow
import os
import tensorflow as tf
from tensorflow.python.platform import gfile

# checkpoint_path = os.path.join('model/model.ckpt')
# reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
# var_to_shape_map = reader.get_variable_to_shape_map()
# for key in var_to_shape_map:
#     print('tensor_name: ', key)

# Path prefix of the checkpoint to inspect (without the ".meta" suffix).
ckpt_path = os.path.join('test1/model/model.ckpt-15776')
# Rebuild the graph structure from the .meta file into the default graph.
saver = tf.train.import_meta_graph(ckpt_path + '.meta', clear_devices=True)
graph = tf.get_default_graph()
with tf.Session(graph=graph) as sess:
    # restore() assigns the saved value to every checkpointed variable.
    # Do not run a variable initializer afterwards: it would overwrite the
    # weights that were just restored.
    saver.restore(sess, ckpt_path)

    # View the result with: tensorboard --logdir=d:/log  --host=127.0.0.1
    # Keep the log directory path simple; the author reports that complex
    # paths can end in "file not found" problems (cause unknown).
    writer = tf.summary.FileWriter("d:/log/", sess.graph)
    # Close the writer so the graph event is flushed to disk before exit.
    writer.close()

1.2 查看ckpt网络结构_v2.py

import tensorflow as tf
from tensorflow.summary import FileWriter
 
# Import the graph structure from the .meta file into the default graph and
# write it out for TensorBoard. The session is used as a context manager so
# it is released when the script is done.
with tf.Session() as sess:
    tf.train.import_meta_graph('zhou/latest_model.ckpt.meta')
    writer = FileWriter("d:/log/", sess.graph)
    # Close explicitly so the graph event is flushed to disk.
    writer.close()

# tensorboard --logdir=d:/log  --host=127.0.0.1

2.查看pb网络结构.py

import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
    # Alternative model: "saved_model/fasterrcnn_resnet101_const.pb"
    model_filename = "test1/frozen_inference_graph.pb"
    # Read the serialized GraphDef. GFile is used instead of the deprecated
    # FastGFile alias.
    with gfile.GFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # No name is passed, so the library's default name prefix is applied to
    # the imported nodes.
    tf.import_graph_def(graph_def)

    # View the result with: tensorboard --logdir=d:/log  --host=127.0.0.1
    LOGDIR = "d:/log/"
    train_writer = tf.summary.FileWriter(LOGDIR)
    train_writer.add_graph(sess.graph)
    # Close the writer so the graph event is flushed to disk before exit.
    train_writer.close()

3.ckpt2pb.py

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python import pywrap_tensorflow


def freeze_graph(cpkt_path, pb_path, output_node_names='logistic_loss/mul'):
    """Freeze a TF1 checkpoint into a single ``.pb`` file.

    The graph is rebuilt from ``cpkt_path + '.meta'``, variable values are
    restored from the checkpoint, and the variables are converted to
    constants before the GraphDef is serialized to ``pb_path``.

    Args:
        cpkt_path: Checkpoint path prefix (without the ``.meta`` suffix).
        pb_path: Destination file for the frozen, serialized GraphDef.
        output_node_names: Output nodes to keep, either a comma-separated
            string or a list/tuple of names. Every name must exist in the
            original graph. Defaults to the node this script was written for.

    Raises:
        ValueError: If no output node name is given.
    """
    # Accept "a,b, c" as well as ["a", "b", "c"]; stray spaces around the
    # commas would otherwise produce node names that do not exist.
    if isinstance(output_node_names, str):
        output_nodes = [name.strip() for name in output_node_names.split(",")]
        output_nodes = [name for name in output_nodes if name]
    else:
        output_nodes = list(output_node_names)
    if not output_nodes:
        raise ValueError("output_node_names must name at least one node")

    saver = tf.train.import_meta_graph(cpkt_path + '.meta', clear_devices=True)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    with tf.Session() as sess:
        # Load the variable values saved in the checkpoint.
        saver.restore(sess, cpkt_path)

        # Replace every variable by a constant holding its current value.
        frozen_graph_def = graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=input_graph_def,  # same content as sess.graph_def
            output_node_names=output_nodes)

    with tf.gfile.GFile(pb_path, 'wb') as fgraph:
        fgraph.write(frozen_graph_def.SerializeToString())
    # Report how many nodes the frozen graph contains.
    print("%d ops in the final graph." % len(frozen_graph_def.node))

if __name__ == '__main__':
    # Checkpoint prefix to read and frozen .pb file to write.
    cpkt_path, pb_path = 'zhou/latest_model.ckpt', "zhou/test.pb"
    # Run the conversion.
    freeze_graph(cpkt_path=cpkt_path, pb_path=pb_path)

    # # 查看节点名称:
    # reader = pywrap_tensorflow.NewCheckpointReader(cpkt_path)
    # var_to_shape_map = reader.get_variable_to_shape_map()
    # for key in var_to_shape_map:
    #     print("tensor_name: ", key)

    # # 查看某个指定节点的权重
    # reader = pywrap_tensorflow.NewCheckpointReader(cpkt_path)
    # var_to_shape_map = reader.get_variable_to_shape_map()
    # w0 = reader.get_tensor("finetune/dense_1/bias")
    # print(w0.shape, type(w0))
    # print(w0[0])
    

    # with tf.Session() as sess:
    #     # 加载模型定义的graph
    #     saver = tf.train.import_meta_graph('model/model.meta')
    #     # 方式一:加载指定文件夹下最近保存的一个模型的数据
    #     saver.restore(sess, tf.train.latest_checkpoint('model/'))
    #     # 方式二:指定具体某个数据,需要注意的是,指定的文件不要包含后缀
    #     # saver.restore(sess, os.path.join(path, 'model.ckpt-1000'))

    #     # 查看模型中的trainable variables
    #     tvs = [v for v in tf.trainable_variables()]
    #     for v in tvs:
    #         print(v.name)
    #         # print(sess.run(v))

    #     # # 查看模型中的所有tensor或者operations
    #     # gv = [v for v in tf.global_variables()]
    #     # for v in gv:
    #     #     print(v.name)

    #     # # 获得几乎所有的operations相关的tensor
    #     # ops = [o for o in sess.graph.get_operations()]
    #     # for o in ops:
    #     #     print(o.name)

4.pb2pbtxt.py

import tensorflow as tf
from tensorflow.python.platform import gfile
from google.protobuf import text_format


def convert_pb_to_pbtxt(root_path, pb_path, pbtxt_path):
    """Convert a binary GraphDef (``.pb``) into its text form (``.pbtxt``).

    Args:
        root_path: Directory that contains the input and receives the output.
        pb_path: File name of the binary GraphDef, relative to ``root_path``.
        pbtxt_path: File name of the text GraphDef to write, relative to
            ``root_path``.
    """
    import os  # local import: this script's import block does not provide os

    graph_def = tf.GraphDef()
    # Join the directory and file name instead of concatenating strings, so a
    # root_path without a trailing separator resolves to the same directory
    # that write_graph() writes into. GFile replaces the deprecated FastGFile.
    with gfile.GFile(os.path.join(root_path, pb_path), 'rb') as f:
        graph_def.ParseFromString(f.read())

    # Import into a throw-away graph: this still rejects a GraphDef that
    # cannot be imported, without adding its nodes to the global default graph.
    with tf.Graph().as_default():
        tf.import_graph_def(graph_def, name='')

    tf.train.write_graph(graph_def, root_path, pbtxt_path, as_text=True)


def convert_pbtxt_to_pb(root_path, pb_path, pbtxt_path):
    """Convert a text GraphDef (``.pbtxt``) into its binary form (``.pb``).

    Args:
        root_path: Directory that contains the input and receives the output.
        pb_path: File name of the binary GraphDef to write, relative to
            ``root_path``.
        pbtxt_path: File name of the text GraphDef, relative to ``root_path``.
    """
    import os  # local import: this script's import block does not provide os

    graph_def = tf.GraphDef()
    # Join the directory and file name instead of concatenating strings, so
    # reading and writing resolve against the same directory.
    # GFile replaces the deprecated FastGFile.
    with tf.gfile.GFile(os.path.join(root_path, pbtxt_path), 'r') as f:
        # Merge the human-readable text into the GraphDef message.
        text_format.Merge(f.read(), graph_def)

    tf.train.write_graph(graph_def, root_path, pb_path, as_text=False)

if __name__ == '__main__':
    # Model directory, plus the binary and text file names inside it.
    root_path, pb_path, pbtxt_path = "test1/", "33.pb", "33.pbtxt"

    # Convert the binary graph to its text form.
    convert_pb_to_pbtxt(root_path=root_path, pb_path=pb_path, pbtxt_path=pbtxt_path)
    # convert_pbtxt_to_pb(root_path, pb_path, pbtxt_path)

5.test_pb.py

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python import pywrap_tensorflow
import cv2
import numpy as np




def recognize(jpg_path, pb_file_path,
              input_tensor_name="IteratorGetNext_1:0",
              output_tensor_name="resnet_v1_50_1/predictions/Softmax:0",
              input_shape=(32, 224, 224, 3)):
    """Smoke-test a ``.pb`` model by running it once on random input.

    Args:
        jpg_path: Currently unused; the model is fed random data instead of
            an image. To test a real image, load it with ``cv2.imread``,
            resize it to the model's input size and feed that array.
        pb_file_path: Path of the serialized GraphDef to load.
        input_tensor_name: Name of the tensor that receives the input batch.
        output_tensor_name: Name of the tensor to fetch.
        input_shape: Shape of the random input batch.

    Returns:
        The value fetched from ``output_tensor_name``.
    """
    # 1. Parse the serialized model into a GraphDef message.
    graph_def = tf.GraphDef()
    with open(pb_file_path, "rb") as f:
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        # 2. Import the GraphDef into a fresh graph, keeping node names as-is.
        tf.import_graph_def(graph_def, name="")

        with tf.Session(graph=graph) as sess:
            sess.run(tf.global_variables_initializer())

            # 3. Look up the input and output tensors by name.
            x = graph.get_tensor_by_name(input_tensor_name)
            y_out = graph.get_tensor_by_name(output_tensor_name)

            img = np.random.normal(size=input_shape)
            print(img.shape)

            # 4. Run the model once.
            output = sess.run(y_out, feed_dict={x: img})

    # Hand the result back instead of discarding it; for a classifier the
    # predicted labels are np.argmax(output, axis=1).
    return output

# Run the smoke test on the frozen model exported from the checkpoint.
recognize(jpg_path="test1/a.jpg", pb_file_path="test1/model2/model.ckpt-15776.pb")

6.pb2tflite.py

import tensorflow as tf
 
# Frozen graph to convert and the .tflite file to produce.
in_path = r"resnet50/model.pb"
out_path = r"resnet50/model.tflite"
# Input and output node names of the frozen graph, and the fixed input shape.
input_arrays = ["input_images"]
input_shapes = {"input_images": [1, 28, 28, 1]}
output_arrays = ["resnet_v2_50/logits/BiasAdd"]

converter = tf.lite.TFLiteConverter.from_frozen_graph(in_path, input_arrays, output_arrays, input_shapes)
tflite_model = converter.convert()
# Use a context manager so the file is flushed and closed deterministically.
with open(out_path, "wb") as tflite_file:
    tflite_file.write(tflite_model)

7.test_tflite.py

import tensorflow as tf
import numpy as np
import cv2 as cv2


#图片处理,
def image_process(image_path):
    """Load an image as a ``(1, 28, 28, 1)`` float32 NumPy array.

    The image is read as single-channel grayscale, resized to 28x28 and given
    leading batch and trailing channel axes. Pixel values are not rescaled.

    Args:
        image_path: Path of the image file to read.

    Returns:
        ``numpy.ndarray`` of dtype float32 and shape ``(1, 28, 28, 1)``.

    Raises:
        OSError: If OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path, 0)  # flag 0: read as grayscale
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError("cannot read image: %r" % (image_path,))
    image = cv2.resize(image, (28, 28))
    # Reshape and cast in NumPy directly; routing this through TensorFlow ops
    # required a throw-away session that was never closed.
    return image.reshape(1, 28, 28, 1).astype(np.float32)

def test_tflite(model_path, image):
    """Run one inference through a TFLite model and print the raw output.

    Args:
        model_path: Path of the ``.tflite`` file.
        image: Input data; must be a NumPy ndarray.
    """
    # Load the model and allocate its tensors.
    tflite_interpreter = tf.lite.Interpreter(model_path)
    tflite_interpreter.allocate_tensors()

    # Only the first input and the first output of the model are used.
    first_input = tflite_interpreter.get_input_details()[0]
    first_output = tflite_interpreter.get_output_details()[0]

    # Feed the input, run the model and print what it produced.
    tflite_interpreter.set_tensor(first_input['index'], image)
    tflite_interpreter.invoke()
    print(tflite_interpreter.get_tensor(first_output['index']))

    # #标签预测
    # w = np.argmax(output_data)#值最大的位置
    # print(w)

def test_pb(model_path, image):
    """Run one inference through a ``.pb`` model and print the output.

    Args:
        model_path: Path of the serialized GraphDef.
        image: Array fed to the ``input_images:0`` tensor.
    """
    with tf.Graph().as_default() as active_graph:
        frozen_def = tf.GraphDef()

        # Parse the serialized model and import it into the active graph,
        # keeping node names unprefixed.
        with open(model_path, "rb") as pb_file:
            frozen_def.ParseFromString(pb_file.read())
            tf.import_graph_def(frozen_def, name="")

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            # Resolve the input and output tensors by name.
            input_tensor = active_graph.get_tensor_by_name("input_images:0")
            logits_tensor = active_graph.get_tensor_by_name("resnet_v2_50/logits/BiasAdd:0")

            # Run the model and print the result.
            print(sess.run(logits_tensor, feed_dict={input_tensor: image}))


# #图片传入与处理
# Sample image and the two model files that are run on it.
image_path = 'resnet50/img/00001.png'
tflite_model = "resnet50/model.tflite"
pb_model = "resnet50/model.pb"

# Preprocess the image once and feed the same array to both models.
image = image_process(image_path)
test_tflite(tflite_model, image)
test_pb(pb_model, image)

8.test_caffe.py

# -*- coding: UTF-8 -*-
import caffe                                                     
import numpy as np
import cv2
from PIL import Image
def test(deploy_proto, caffe_model, img_path, output_path="data/test.jpg"):
    """Run one image through a Caffe net and save the post-processed result.

    Args:
        deploy_proto: Path of the deploy ``.prototxt`` network definition.
        caffe_model: Path of the ``.caffemodel`` weights.
        img_path: Path of the input image.
        output_path: Where the resulting image is written.

    Raises:
        OSError: If the input image cannot be read or the result cannot be
            written.
    """
    # Load the network definition and weights in test mode.
    net = caffe.Net(deploy_proto, caffe_model, caffe.TEST)

    # Load the input image.
    test_img = cv2.imread(img_path)
    if test_img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError("cannot read image: %r" % (img_path,))
    test_img = np.transpose(test_img, (2, 0, 1))  # move channels to the front
    test_img = np.expand_dims(test_img, 0)        # add a leading batch axis
    test_img = test_img / 127.5 - 1               # map 0..255 to -1..1

    # Run the forward pass and read the 'conv2' blob of the first batch item.
    net.blobs['data'].data[...] = test_img
    net.forward()
    result = net.blobs['conv2'].data[0]

    # Post-processing: add the input back onto the blob and clamp to [-1, 1].
    # NOTE(review): this looks like the net predicts a residual; confirm.
    result = result + test_img[0]
    result = np.clip(result, -1, 1)

    # Move channels back to the last axis and map -1..1 back to 0..255.
    img = np.transpose(result, (1, 2, 0))
    img = (img + 1) * 127.5
    # cv2.imwrite reports failure only through its return value.
    if not cv2.imwrite(output_path, img):
        raise OSError("cannot write image: %r" % (output_path,))
    print(img.shape)

if __name__ == '__main__':
    # Network definition, trained weights and the image to run through them.
    deploy_proto, caffe_model = "caffe/yolov3_me.prototxt", 'caffe/yolov3_me.caffemodel'
    img_path = 'caffe/dog.jpg'

    test(deploy_proto=deploy_proto, caffe_model=caffe_model, img_path=img_path)

9.h52pb.py

from keras.models import load_model
import tensorflow as tf
import os 
import os.path as osp
from keras import backend as K
# Path parameters: where the .h5 file lives and what the .pb file is called.
input_path = './'
weight_file = 'model/a.h5'
# Same relative name as the weight file, with the extension swapped to .pb.
output_graph_name = "{}.pb".format(osp.splitext(weight_file)[0])
weight_file_path = osp.join(input_path, weight_file)
# Conversion function
def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
    """Freeze a loaded Keras model into a binary ``.pb`` GraphDef.

    Every model output is wrapped in an identity op named
    ``out_prefix + <1-based index>``; those ops become the output nodes of
    the frozen graph.

    Args:
        h5_model: Keras model, e.g. the result of ``load_model``.
        output_dir: Directory the ``.pb`` file is written into.
        model_name: File name of the ``.pb`` file, relative to ``output_dir``.
        out_prefix: Name prefix of the generated output nodes.
        log_tensorboard: If true, also write the frozen graph as a
            TensorBoard log into ``output_dir``.
    """
    from tensorflow.python.framework import graph_util, graph_io

    pb_file = osp.join(output_dir, model_name)
    # model_name may contain sub-directories (e.g. "model/a.pb"), so make
    # sure both output_dir and the file's own directory exist.
    for directory in (output_dir, osp.dirname(pb_file)):
        if directory and not osp.isdir(directory):
            os.makedirs(directory)

    out_nodes = []
    # Iterate the ``outputs`` list itself. Indexing ``h5_model.output[i]``
    # slices the tensor when the model has a single output, because
    # ``output`` is then one tensor rather than a list.
    for index, tensor in enumerate(h5_model.outputs):
        node_name = out_prefix + str(index + 1)
        tf.identity(tensor, name=node_name)
        out_nodes.append(node_name)

    sess = K.get_session()
    init_graph = sess.graph.as_graph_def()
    # Replace variables by constants, keeping only what the outputs need.
    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
    graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)

    if log_tensorboard:
        from tensorflow.python.tools import import_pb_to_tensorboard
        import_pb_to_tensorboard.import_to_tensorboard(pb_file, output_dir)
# Output directory: "trans_model" under the current working directory.
output_dir = osp.join(os.getcwd(), "trans_model")
# Load the Keras model, then freeze it to a .pb file.
h5_model = load_model(weight_file_path)
h5_to_pb(h5_model, output_dir, output_graph_name)
print('model saved')

  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值