安装好C++版的TensorFlow之后,我们就可以用C++来部署python训练好的TensorFlow模型了。安装C++版的TensorFlow的教程可以参考这里。部署TensorFlow模型主要分为两步,第一步是用python训练模型,然后保存模型为.pb格式的二进制文件;第二步则是在C++中加载python保存的模型并进行预测。
1、python训练模型并保存
这里我们以mnist数据集为例,训练一个三层的多层感知机对手写数字图片进行识别,具体的训练代码如下:
#coding=utf-8
import pickle
import numpy as np
import tensorflow as tf
import time
import os
from tensorflow.python.framework.graph_util import convert_variables_to_constants
# 定义一个mnist数据集的类
class mnistReader():
def __init__(self, mnistPath, onehot = True):
self.mnistPath = mnistPath
self.onehot = onehot
self.batch_index = 0
print ('read:',self.mnistPath)
fo = open(self.mnistPath, 'rb')
self.train_set,self.valid_set,self.test_set = pickle.load(fo, encoding='bytes')
fo.close()
self.data_label_train = list(zip(self.train_set[0], self.train_set[1]))
np.random.shuffle(self.data_label_train)
# 获取下一个训练集的batch
def next_train_batch(self, batch_size = 100):
if self.batch_index < len(self.data_label_train)/batch_size:
print ("batch_index:",self.batch_index )
datum = self.data_label_train[self.batch_index*batch_size:(self.batch_index+1)*batch_size]
self.batch_index+=1
return self._decode(datum, self.onehot)
else:
self.batch_index=0
np.random.shuffle(self.data_label_train)
datum=self.data_label_train[self.batch_index*batch_size:(self.batch_index+1)*batch_size]
self.batch_index+=1
return self._decode(datum,self.onehot)
# 获取测试集的数据
def test_data(self):
tdata, tlabel = self.test_set
data_label_test=list(zip(tdata,tlabel))
return self._decode(data_label_test,self.onehot)
# 把一个batch的训练数据转换为可以放入模型训练的数据
def _decode(self, datum, onehot):
rdata=list()
rlabel=list()
if onehot:
for d,l in datum:
img = np.reshape(d, (28,28))
img = np.expand_dims(img, 2)
rdata.append(img)
hot = np.zeros(10)
hot[int(l)] = 1
rlabel.append(hot)
else:
for d,l in datum:
img = np.reshape(d, (28,28))
img = np.expand_dims(img,2)
rdata.append(img)
rlabel.append(int(l))
return rdata,rlabel
#多层感知机模型(只有一个隐藏层)
def multi_perceptron():
batch_size = 100 # batch大小
height = 28 # 图片高度
width = 28 # 图片宽度
channel = 1 # 图片通道数
in_units = 784 # 多层感知机的输入
h1_units=300 # MLP隐藏层的输出节点数
mnist_path = "E:/testdata/mnist.pkl" # mnist数据集路径
save_path = "./output" #保存模型的路径
images = tf.placeholder(tf.float32, shape = [None, height, width, channel],name = "images")
labels = tf.placeholder(tf.float32,[None, 10], name = "labels")
w1=tf.Variable(tf.truncated_normal([i