本章提供了TensorFlow模型常用的一些处理脚本(另附一个Caffe模型测试脚本),包括:
tensorboard查看ckpt网络结构
tensorboard查看pb网络结构
ckpt模型转pb模型
pb模型转pbtxt文件
测试pb模型
pb模型转tflite模型
测试tflite模型
h5模型转pb模型
测试caffe模型
1.1 查看ckpt网络结构_v1.py
使用方法:先运行脚本,将网络结构写入日志目录 d:/log;
然后在控制台输入命令:tensorboard --logdir=d:/log --host=127.0.0.1
最后在浏览器中打开:http://127.0.0.1:6006/
from tensorflow.python import pywrap_tensorflow
import os
import tensorflow as tf
from tensorflow.python.platform import gfile
# Optional: list every variable stored in a checkpoint without building a graph.
# checkpoint_path = os.path.join('model/model.ckpt')
# reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
# var_to_shape_map = reader.get_variable_to_shape_map()
# for key in var_to_shape_map:
#     print('tensor_name: ', key)

# Checkpoint prefix (the .meta/.index/.data files share this prefix).
ckpt_path = os.path.join('test1/model/model.ckpt-15776')
# Rebuild the graph from the .meta file; clear_devices drops the device
# placements recorded at training time.
saver = tf.train.import_meta_graph(ckpt_path + '.meta', clear_devices=True)
graph = tf.get_default_graph()
with tf.Session(graph=graph) as sess:
    # restore() assigns the saved variables. No initializer is run: one run
    # after restore would overwrite the restored weights, and initializer ops
    # would also show up as extra nodes in the logged graph.
    saver.restore(sess, ckpt_path)
    # Write the graph for TensorBoard; view it with:
    #   tensorboard --logdir=d:/log --host=127.0.0.1
    # Keep the log directory path simple: more complex paths have been
    # observed not to be found (cause not investigated).
    writer = tf.summary.FileWriter("d:/log/", sess.graph)
    # close() flushes the pending graph event to disk.
    writer.close()
1.2 查看ckpt网络结构_v2.py
import tensorflow as tf
from tensorflow.summary import FileWriter
# Rebuild the graph stored in the .meta file into the default graph.
tf.train.import_meta_graph('zhou/latest_model.ckpt.meta')
with tf.Session() as sess:
    # Write the graph for TensorBoard; view it with:
    #   tensorboard --logdir=d:/log --host=127.0.0.1
    writer = FileWriter("d:/log/", sess.graph)
    # Flush the graph event to disk and release the file handle.
    writer.close()
2.查看pb网络结构.py
import tensorflow as tf
from tensorflow.python.platform import gfile
# Frozen graph (.pb) to visualize.
# model_filename = "saved_model/fasterrcnn_resnet101_const.pb"
model_filename = "test1/frozen_inference_graph.pb"
# View with: tensorboard --logdir=d:/log --host=127.0.0.1
LOGDIR = "d:/log/"
with tf.Session() as sess:
    with gfile.GFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Import into the session's (default) graph; with no name given, TF
    # prefixes the imported nodes with "import/".
    tf.import_graph_def(graph_def)
    train_writer = tf.summary.FileWriter(LOGDIR)
    train_writer.add_graph(sess.graph)
    # Flush the graph event to disk and release the file handle.
    train_writer.close()
3.ckpt2pb.py
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python import pywrap_tensorflow
def freeze_graph(cpkt_path, pb_path, output_node_names='logistic_loss/mul'):
    """Freeze a checkpoint into a single self-contained .pb GraphDef.

    The graph is rebuilt from ``cpkt_path + '.meta'``, the checkpointed
    variable values are restored, and variables are replaced by constants
    via ``graph_util.convert_variables_to_constants``.

    Args:
        cpkt_path: checkpoint prefix, e.g. 'zhou/latest_model.ckpt'.
        pb_path: destination path of the serialized frozen GraphDef.
        output_node_names: output node name(s), which must exist in the
            original graph. Pass a comma-separated string or a list/tuple of
            names. The default is the node name previously hard-coded here.
    """
    # Alternative: locate the newest checkpoint in a directory.
    # checkpoint = tf.train.get_checkpoint_state('zhou/')
    # cpkt_path = checkpoint.model_checkpoint_path
    # Other output names seen for detection models:
    # "num_detections,raw_detection_boxes,raw_detection_scores"
    if isinstance(output_node_names, str):
        output_node_names = output_node_names.split(",")
    # Tolerate "a, b" style input and drop empty entries.
    node_names = [name.strip() for name in output_node_names if name.strip()]

    # Build into a fresh graph so repeated calls do not import the same
    # nodes again into the global default graph.
    graph = tf.Graph()
    with graph.as_default():
        saver = tf.train.import_meta_graph(cpkt_path + '.meta', clear_devices=True)
        input_graph_def = graph.as_graph_def()
        with tf.Session(graph=graph) as sess:
            saver.restore(sess, cpkt_path)  # load the checkpointed variable values
            frozen_graph_def = graph_util.convert_variables_to_constants(
                sess=sess,
                input_graph_def=input_graph_def,
                output_node_names=node_names)
    with tf.gfile.GFile(pb_path, 'wb') as fgraph:
        fgraph.write(frozen_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(frozen_graph_def.node))
if __name__ == '__main__':
    # Checkpoint prefix to freeze, and the frozen .pb file to produce.
    cpkt_path, pb_path = 'zhou/latest_model.ckpt', "zhou/test.pb"
    freeze_graph(cpkt_path, pb_path)

    # Reference snippets for inspecting a checkpoint (uncomment as needed):
    #
    # # List the stored variable names:
    # reader = pywrap_tensorflow.NewCheckpointReader(cpkt_path)
    # for key in reader.get_variable_to_shape_map():
    #     print("tensor_name: ", key)
    #
    # # Read the value of one stored variable:
    # w0 = reader.get_tensor("finetune/dense_1/bias")
    # print(w0.shape, type(w0))
    # print(w0[0])
    #
    # # Restore a graph plus weights, then list its variables / ops:
    # with tf.Session() as sess:
    #     saver = tf.train.import_meta_graph('model/model.meta')
    #     # Option 1: the most recently saved checkpoint in a directory.
    #     saver.restore(sess, tf.train.latest_checkpoint('model/'))
    #     # Option 2: a specific checkpoint prefix (give it without a suffix).
    #     # saver.restore(sess, os.path.join(path, 'model.ckpt-1000'))
    #     for v in tf.trainable_variables():
    #         print(v.name)
    #         # print(sess.run(v))
    #     # for v in tf.global_variables():
    #     #     print(v.name)
    #     # for o in sess.graph.get_operations():
    #     #     print(o.name)
4.pb2pbtxt.py
import tensorflow as tf
from tensorflow.python.platform import gfile
from google.protobuf import text_format
def convert_pb_to_pbtxt(root_path, pb_path, pbtxt_path):
    """Convert a binary GraphDef (.pb) into its human-readable text form (.pbtxt).

    Args:
        root_path: directory that holds the input and receives the output.
        pb_path: file name of the binary GraphDef inside ``root_path``.
        pbtxt_path: file name of the text GraphDef to write inside ``root_path``.
    """
    import os

    # Build the read path with os.path.join, the same way tf.train.write_graph
    # joins (logdir, name) for the write path. Plain string concatenation
    # breaks when root_path has no trailing separator.
    with gfile.GFile(os.path.join(root_path, pb_path), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Serializing to text needs only the GraphDef; importing it into the
    # global default graph is unnecessary.
    tf.train.write_graph(graph_def, root_path, pbtxt_path, as_text=True)
def convert_pbtxt_to_pb(root_path, pb_path, pbtxt_path):
    """Convert a text GraphDef (.pbtxt) back into binary form (.pb).

    Args:
        root_path: directory that holds the input and receives the output.
        pb_path: file name of the binary GraphDef to write inside ``root_path``.
        pbtxt_path: file name of the text GraphDef inside ``root_path``.
    """
    import os

    # Join the same way tf.train.write_graph does, so that reading and writing
    # agree on the directory.
    with tf.gfile.GFile(os.path.join(root_path, pbtxt_path), 'r') as f:
        file_content = f.read()
    graph_def = tf.GraphDef()
    # Merge the human-readable text into graph_def.
    text_format.Merge(file_content, graph_def)
    tf.train.write_graph(graph_def, root_path, pb_path, as_text=False)
if __name__ == '__main__':
    # Both files live in root_path. Swap the commented call in for the
    # reverse direction (.pbtxt -> .pb).
    root_path, pb_path, pbtxt_path = "test1/", "33.pb", "33.pbtxt"
    convert_pb_to_pbtxt(root_path, pb_path, pbtxt_path)
    # convert_pbtxt_to_pb(root_path, pb_path, pbtxt_path)
5.test_pb.py
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python import pywrap_tensorflow
import cv2
import numpy as np
def recognize(jpg_path, pb_file_path,
              input_name="IteratorGetNext_1:0",
              output_name="resnet_v1_50_1/predictions/Softmax:0"):
    """Run a frozen graph (.pb) on a dummy batch and return the fetched output.

    Args:
        jpg_path: path of a test image. It is currently unused: the graph is
            fed a random batch of shape (32, 224, 224, 3). See the commented
            cv2 lines below to feed a real image instead.
        pb_file_path: frozen GraphDef file to load.
        input_name: name of the tensor to feed. The default is the name
            previously hard-coded here.
        output_name: name of the tensor to fetch. The default is the name
            previously hard-coded here.

    Returns:
        The value of ``output_name`` as returned by ``Session.run``.
    """
    # Other names seen for detection models: "image_tensor:0" / "Cast:0".
    with tf.Graph().as_default() as graph:
        graph_def = tf.GraphDef()
        # 1. Parse the serialized model file into a GraphDef.
        with open(pb_file_path, "rb") as f:
            graph_def.ParseFromString(f.read())
        # 2. Import it with name="" so tensor names match the original graph.
        tf.import_graph_def(graph_def, name="")
        # 3. Look up the input and output tensors by name.
        x = graph.get_tensor_by_name(input_name)
        y_out = graph.get_tensor_by_name(output_name)
        # NOTE(review): a batch of 32 presumably matches a batch size fixed
        # inside the graph; confirm.
        img = np.random.normal(size=(32, 224, 224, 3))
        # img = cv2.imread(jpg_path)
        # img = cv2.resize(img, (224, 224))
        # img = np.reshape(img, (1, 224, 224, 3))
        print(img.shape)
        # An imported GraphDef registers no variables, so no initializer is needed.
        with tf.Session(graph=graph) as sess:
            output = sess.run(y_out, feed_dict={x: img})
    # pred = np.argmax(output, axis=1)
    return output
# Only run the demo when executed as a script, not on import.
if __name__ == '__main__':
    recognize("test1/a.jpg", "test1/model2/model.ckpt-15776.pb")
6.pb2tflite.py
import tensorflow as tf
in_path = r"resnet50/model.pb"
out_path = r"resnet50/model.tflite"
# Input/output node names of the frozen graph, and the fixed input shape
# the converter should bake into the .tflite model.
input_arrays = ["input_images"]
input_shapes = {"input_images": [1, 28, 28, 1]}
output_arrays = ["resnet_v2_50/logits/BiasAdd"]
converter = tf.lite.TFLiteConverter.from_frozen_graph(
    in_path, input_arrays, output_arrays, input_shapes)
tflite_model = converter.convert()
# Use a context manager so the file is flushed and closed deterministically.
with open(out_path, "wb") as f:
    f.write(tflite_model)
7.test_tflite.py
import tensorflow as tf
import numpy as np
import cv2 as cv2
def image_process(image_path):
    """Load an image as a (1, 28, 28, 1) float32 array for the model.

    The image is read as grayscale and resized to 28x28. The result is a
    numpy ndarray, which is what ``Interpreter.set_tensor`` and ``feed_dict``
    accept.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path, 0)  # flag 0 = load as grayscale
    if image is None:
        # cv2.imread returns None instead of raising on a missing/unreadable file.
        raise FileNotFoundError("cannot read image: %s" % image_path)
    image = cv2.resize(image, (28, 28))
    # Plain numpy gives the same values as the previous TF round-trip, which
    # built new graph ops and leaked an unclosed Session on every call.
    return np.reshape(image, (1, 28, 28, 1)).astype(np.float32)
def test_tflite(model_path, image):
    """Run a .tflite model on ``image``; print and return its first output.

    Args:
        model_path: path of the .tflite file.
        image: input data; must be a numpy ndarray matching the model's first
            input.

    Returns:
        The first output tensor, as returned by ``Interpreter.get_tensor``.
    """
    # Load the model and allocate its tensors.
    interpreter = tf.lite.Interpreter(model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    # Run inference.
    interpreter.set_tensor(input_details[0]['index'], image)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    print(output_data)
    # For classification outputs, the predicted label would be:
    # w = np.argmax(output_data)
    return output_data
def test_pb(model_path, image,
            input_name="input_images:0",
            output_name="resnet_v2_50/logits/BiasAdd:0"):
    """Run a frozen .pb model on ``image``; print and return the fetched output.

    Args:
        model_path: frozen GraphDef file to load.
        image: value fed to ``input_name``.
        input_name: name of the tensor to feed. The default is the name
            previously hard-coded here.
        output_name: name of the tensor to fetch. The default is the name
            previously hard-coded here.

    Returns:
        The value of ``output_name`` as returned by ``Session.run``.
    """
    with tf.Graph().as_default() as graph:
        graph_def = tf.GraphDef()
        # 1. Parse the serialized model file into a GraphDef.
        with open(model_path, "rb") as f:
            graph_def.ParseFromString(f.read())
        # 2. Import it with name="" so tensor names match the original graph.
        tf.import_graph_def(graph_def, name="")
        # 3. Look up the input and output tensors by name.
        x = graph.get_tensor_by_name(input_name)
        y_out = graph.get_tensor_by_name(output_name)
        # An imported GraphDef registers no variables, so no initializer is needed.
        with tf.Session(graph=graph) as sess:
            output = sess.run(y_out, feed_dict={x: image})
    print(output)
    return output
# Only run the comparison when executed as a script, not on import.
if __name__ == '__main__':
    # Feed the same preprocessed image to the .tflite and the .pb model so
    # their printed outputs can be compared.
    image_path = 'resnet50/img/00001.png'
    image = image_process(image_path)
    tflite_model = "resnet50/model.tflite"
    test_tflite(tflite_model, image)
    pb_model = "resnet50/model.pb"
    test_pb(pb_model, image)
8.test_caffe.py
# -*- coding: UTF-8 -*-
import caffe
import numpy as np
import cv2
from PIL import Image
def test(deploy_proto, caffe_model, img_path, out_path="data/test.jpg",
         input_blob='data', output_blob='conv2'):
    """Run a Caffe net on one image and save the reconstructed image.

    Processing steps:
    - The image is transposed HWC -> CHW, given a batch axis, scaled from
      [0, 255] to [-1, 1], and written into ``input_blob``.
    - After the forward pass, the first item of ``output_blob`` is added to
      the input image and clipped to [-1, 1].
    - The result is mapped back to [0, 255] and saved to ``out_path``.

    Args:
        deploy_proto: network definition (.prototxt).
        caffe_model: trained weights (.caffemodel).
        img_path: input image path.
        out_path: where to save the result. The default is the path
            previously hard-coded here.
        input_blob: input blob name. The default is the name previously
            hard-coded here.
        output_blob: output blob name. The default is the name previously
            hard-coded here.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    # Load the network definition together with its weights.
    net = caffe.Net(deploy_proto, caffe_model, caffe.TEST)
    test_img = cv2.imread(img_path)
    if test_img is None:
        # cv2.imread returns None instead of raising on a missing/unreadable file.
        raise FileNotFoundError("cannot read image: %s" % img_path)
    # HWC -> 1xCxHxW, scaled to [-1, 1].
    test_img = np.expand_dims(np.transpose(test_img, (2, 0, 1)), 0)
    test_img = test_img / 127.5 - 1
    # NOTE(review): assumes the input blob already has the image's shape; confirm.
    net.blobs[input_blob].data[...] = test_img
    net.forward()
    result = net.blobs[output_blob].data[0]
    # Post-processing: residual add. This requires the output blob to have
    # the same shape as the input image.
    result = np.clip(result + test_img[0], -1, 1)
    # Map [-1, 1] back to [0, 255] in HWC layout.
    img = (np.transpose(result, (1, 2, 0)) + 1) * 127.5
    cv2.imwrite(out_path, img)
    print(img.shape)
if __name__ == '__main__':
    # Network definition, trained weights, and the image to run through them.
    test(deploy_proto="caffe/yolov3_me.prototxt",
         caffe_model='caffe/yolov3_me.caffemodel',
         img_path='caffe/dog.jpg')
9.h52pb.py
from keras.models import load_model
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K
# Path parameters.
input_path = './'
weight_file = 'model/a.h5'
weight_file_path = osp.join(input_path, weight_file)
# Output file NAME only (e.g. 'a.pb'). It is joined onto the output
# directory by tf.train.write_graph, which does not create intermediate
# directories, so the previous 'model/a.pb' would need a 'model/'
# subdirectory to already exist there. splitext also handles extensions
# other than '.h5' (e.g. '.hdf5').
output_graph_name = osp.splitext(osp.basename(weight_file))[0] + '.pb'
def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
    """Freeze a loaded Keras model into a binary .pb GraphDef.

    Each model output is wrapped in an identity op named ``out_prefix + "1"``,
    ``out_prefix + "2"``, ... so that the frozen graph has predictable output
    nodes. Variables in the Keras session graph are converted to constants.

    Args:
        h5_model: loaded Keras model (e.g. from ``keras.models.load_model``).
        output_dir: directory to write into; created, with parents, if missing.
        model_name: file name of the .pb inside ``output_dir``.
        out_prefix: prefix of the output identity node names.
        log_tensorboard: if True, also import the .pb into a TensorBoard log
            written to ``output_dir``.
    """
    from tensorflow.python.framework import graph_util, graph_io

    os.makedirs(output_dir, exist_ok=True)
    out_nodes = []
    # Use model.outputs, which is always a list. For a single-output model,
    # model.output is a bare tensor, so model.output[0] would slice its first
    # row instead of selecting the output.
    for i, model_output in enumerate(h5_model.outputs, start=1):
        identity = tf.identity(model_output, name=out_prefix + str(i))
        # Record the actual op name: TF uniquifies it if the name is taken.
        out_nodes.append(identity.op.name)
    sess = K.get_session()
    init_graph = sess.graph.as_graph_def()
    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
    graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
    if log_tensorboard:
        from tensorflow.python.tools import import_pb_to_tensorboard
        import_pb_to_tensorboard.import_to_tensorboard(osp.join(output_dir, model_name), output_dir)
# Only run the conversion when executed as a script, not on import.
if __name__ == '__main__':
    # Output directory: ./trans_model under the current working directory.
    output_dir = osp.join(os.getcwd(), "trans_model")
    h5_model = load_model(weight_file_path)
    # print(h5_model.outputs)
    h5_to_pb(h5_model, output_dir=output_dir, model_name=output_graph_name)
    print('model saved')