python预测模型导出_Tensorflow导出pb模型,并在python和matlab下分别进行预测

tensorflow下训练完模型测试程序比较杂乱,特此整理一下。

1、我都是在linux下训练,windows下调用测试,训练保存模型如下所示。

2、然后调用frozen_model.py将模型进行固化,这里需要注意一点就是网络输出结点的名称,可以在tensorboard查看GRAPHS中网络输出结点名或训练时进行命名。

frozen_model.py

import tensorflow as tf

from tensorflow.python.framework import graph_util

def freeze_graph(input_checkpoint,output_graph):

# 原模型中输出节点名称

output_node_names = "generator/decoder_1/output_node"

saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)

graph = tf.get_default_graph()

input_graph_def = graph.as_graph_def()

with tf.Session() as sess:

saver.restore(sess, input_checkpoint)

output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定

sess=sess,

input_graph_def=input_graph_def,# 等于:sess.graph_def

output_node_names=output_node_names.split(","))

with tf.gfile.GFile(output_graph, "wb") as f:

f.write(output_graph_def.SerializeToString())

print("%d ops in the final graph." % len(output_graph_def.node))

tf.reset_default_graph()

input_checkpoint = "tensorflow_model/spot_train/model-500"

output_graph = "frozen_model/frozen_model.pb"

freeze_graph(input_checkpoint,output_graph)

3、得到pb模型后,调用test.py进行测试。需要注意输入输出tensor名字一定要写对,一般结点名字后面加":0"就是对应tensor名,可以在这个网站打开pb模型查看tensor名

test.py

#-*- coding:utf-8 -*-

import os

import tensorflow as tf

from tensorflow.python.framework import graph_util

import numpy as np

import scipy.io

from tensorflow.python.platform import gfile

tf.reset_default_graph()

pb_file_path = 'model_package/frozen_model/'

result_file_path = 'test_results/'

def preprocess(x):

Max = np.max(x)

Min = np.min(x)

x = (x-Min)/(Max-Min)

return x*2-1

def deprocess(x):

return (x+1)/2

data = scipy.io.loadmat('1.mat')['data']

flatten_img = preprocess(np.reshape(data, [1, 256,256,1]))

sess = tf.Session()

with gfile.FastGFile(pb_file_path + 'frozen_model.pb', 'rb') as f: #加载模型

graph_def = tf.GraphDef()

graph_def.ParseFromString(f.read())

sess.graph.as_default()

tf.import_graph_def(graph_def, name='') # 导入计算图

# 初始化

#sess.run(tf.global_variables_initializer())

x = sess.graph.get_tensor_by_name('batch:1')

y = sess.graph.get_tensor_by_name('generator/decoder_1/output_node:0')

y_out=sess.run(y,feed_dict={x:flatten_img})

scipy.io.savemat(result_file_path+'test.mat', {'output':y_out})

后面为方便matlab调用,又整理成类了

python_test.py

#-*- coding:utf-8 -*-

import os

import tensorflow as tf

from tensorflow.python.framework import graph_util

import numpy as np

import scipy.io

from tensorflow.python.platform import gfile

from glob import glob

#pb_file_path = 'model_package/frozen_model/'

class Predict(object):

def __init__(self):

tf.reset_default_graph()

self.sess = tf.Session()

with gfile.FastGFile('frozen_model/frozen_model.pb', 'rb') as f: #加载模型

graph_def = tf.GraphDef()

graph_def.ParseFromString(f.read())

sess.graph.as_default()

tf.import_graph_def(graph_def, name='') #导入计算图

print('load model success')

#sess.run(tf.global_variables_initializer()) #初始化

self.x = sess.graph.get_tensor_by_name('batch:1')

self.y = sess.graph.get_tensor_by_name('generator/decoder_1/output_node:0')

def preprocess(x):

Max = np.max(x)

Min = np.min(x)

x = (x-Min)/(Max-Min)

return x*2-1

def deprocess(x):

return (x+1)/2

def predict(self,input_path):

#加载测试输出数据

img= scipy.io.loadmat(input_path)['data']

flatten_img = preprocess(np.reshape(img, [1, 256,256,1]))

y_out = sess.run(self.y,feed_dict={self.x:flatten_img})

y_out = np.squeeze(deprocess(y_out))

scipy.io.savemat('test_results/'+input_path[17:], {'output':y_out})

if __name__ == '__main__':

model = Predict()

model.predict("1.mat")

4、Matlab中测试采用的是调用python测试脚本还实现的。

clear;close all;clc

clear classes

tf = py.importlib.import_module('tensorflow');

np = py.importlib.import_module('numpy');

%plt = py.importlib.import_module('matplotlib.pyplot');

sio = py.importlib.import_module('scipy.io');

obj = py.importlib.import_module('python_test'); %python测试脚本路径

py.importlib.reload(obj);

a = py.pix2pix_test.Predict();

a.predict('test_data/2.bmp')

a.predict('test_data/3.bmp')

写完回头看了一眼,咋这么混乱都没说清也就自己能看懂了,先这样吧后面要提高一下写作能力了,再接再厉,加油!

本文地址:https://blog.csdn.net/megaoliyuanzhende/article/details/107365281

如您对本文有疑问或者有任何想说的,请点击进行留言回复,万千网友为您解惑!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值