1 神经网络结构
1.1 保存*.pb模型
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
import numpy as np
import os
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
from mpl_toolkits.mplot3d import Axes3D
# Font used for the Chinese axis labels and legend of the plots
# (path of the AR PL UKai font on an Ubuntu system).
font = FontProperties(fname='/usr/share/fonts/truetype/arphic/ukai.ttc')

# Where the model, its checkpoint name and the TensorBoard logs are stored.
MODEL_SAVE_PATH = "./models"
MODEL_NAME_meta = "nn_model.ckpt"
MODEL_NAME_pb = "nn_model.pb"
LOG_DIR = "./logs/NNmergelog"

# Simulated training data: 250 evenly spaced float32 samples on [-1, 1],
# as a (250, 1) column, with targets y = x^2 - 0.5 * x + N(0, 0.05) noise.
x_data = np.linspace(-1, 1, 250, dtype=np.float32).reshape(-1, 1)
noise = np.random.normal(0, 0.05, x_data.shape).astype(np.float32)
y_data = x_data ** 2 - 0.5 * x_data + noise

# Layer widths of the 1 -> 10 -> 1 network; the second layer consumes
# exactly what the first layer produces.
input_size_1 = 1
output_size_1 = 10
input_size_2 = output_size_1
output_size_2 = 1
def saved_model_pb():
    """Build and train a 1-10-1 regression network, then freeze it to a *.pb file.

    The network (ReLU hidden layer, linear output) is fitted to the
    module-level ``x_data`` / ``y_data`` with 301 full-batch gradient-descent
    steps (learning rate 0.1).  Summaries are written to ``LOG_DIR`` for
    TensorBoard and progress is printed every 50 steps.  When training has
    finished, the variables are folded into constants and the resulting
    GraphDef is written to ``MODEL_SAVE_PATH/MODEL_NAME_pb``.

    All nodes are added to the current default graph.
    """
    # Input layer: placeholders for features and targets, both [None, 1].
    with tf.name_scope("Input"):
        xs = tf.placeholder(tf.float32, [None, 1], name='x')
        ys = tf.placeholder(tf.float32, [None, 1], name='y')
    # Hidden layer: relu(xs @ weights_1 + biases_1).
    with tf.name_scope("Layer1"):
        weights_1 = tf.Variable(tf.random_normal([input_size_1, output_size_1]), name='weights_1')
        biases_1 = tf.Variable(tf.zeros([1, output_size_1]), name='biases_1')
        layer_1 = tf.nn.relu(tf.matmul(xs, weights_1) + biases_1)
        tf.summary.histogram('weights_1', weights_1)
        tf.summary.histogram('biases_1', biases_1)
        tf.summary.histogram('layer_1', layer_1)
    # Output layer: linear.  The node is explicitly named so that it can be
    # fetched as "Output/predictions" from the frozen graph later on.
    with tf.name_scope("Output"):
        weights_2 = tf.Variable(tf.random_normal([input_size_2, output_size_2]), name='weights_2')
        biases_2 = tf.Variable(tf.zeros([1, output_size_2]), name='biases_2')
        outputs_2 = tf.matmul(layer_1, weights_2)
        prediction = tf.add(outputs_2, biases_2, name="predictions")
        tf.summary.histogram('weights_2', weights_2)
        tf.summary.histogram('biases_2', biases_2)
        tf.summary.histogram('prediction', prediction)
    # Loss: squared error summed per sample, averaged over the batch.
    with tf.name_scope("Loss"):
        loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), axis=[1]))
        tf.summary.scalar('loss', loss)
        tf.summary.histogram('loss', loss)
    # Optimizer that minimizes the loss.
    with tf.name_scope("Train_Step"):
        train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
    # Merge every summary defined above into one op.
    merged = tf.summary.merge_all()

    with tf.Session() as sess:
        summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
        try:
            sess.run(tf.global_variables_initializer())
            for i in range(301):
                # One full-batch training step plus the summaries/loss.
                summary, train_step_value, loss_value, _ = sess.run(
                    [merged, train_step, loss, prediction],
                    feed_dict={xs: x_data, ys: y_data})
                if i % 50 == 0:
                    # Report training progress every 50 steps.
                    w1 = sess.run(weights_1)
                    w2 = sess.run(weights_2)
                    print("Weights_1 :{}".format(w1))
                    print("weights_2 :{}".format(w2))
                    print("Loss :{}".format(loss_value))
                    print(prediction)
                    print(loss)
                    print(train_step_value)
                summary_writer.add_summary(summary, i)
        finally:
            summary_writer.close()

        # Freeze exactly once, AFTER the last update.  Freezing inside the loop
        # (before the training run) was repeated 301 times and captured the
        # weights from before the final step.
        constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ["Input/x", "Input/y", "Output/predictions"])

    # Write the frozen GraphDef; create the target directory if needed.
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
    with tf.gfile.FastGFile(os.path.join(MODEL_SAVE_PATH, MODEL_NAME_pb), mode="wb") as f:
        f.write(constant_graph.SerializeToString())
1.2 载入*.pb模型
def load_pb_model():
    """Load the frozen model written by ``saved_model_pb`` and plot its fit.

    A fresh noisy sample of the training curve (y = x^2 - 0.5 * x) is drawn.
    The graph stored at ``MODEL_SAVE_PATH/MODEL_NAME_pb`` is run on it, and
    the prediction curve is plotted against the noisy targets.  The figure is
    saved to ``./images/pb_load.png`` and then shown.
    """
    # Evaluation data: same curve as the training data, new noise draw.
    x_data = np.linspace(-1, 1, 250, dtype=np.float32)[:, np.newaxis]
    noise = np.random.normal(0, 0.05, x_data.shape).astype(np.float32)
    y_data = np.square(x_data) - 0.5*x_data + noise

    # Import into a private graph.  With the default graph, names such as
    # "Input/x" could collide with a graph already built in this process
    # (e.g. by saved_model_pb()), and lookups would hit the wrong nodes.
    graph = tf.Graph()
    with graph.as_default():
        with gfile.GFile(os.path.join(MODEL_SAVE_PATH, MODEL_NAME_pb), "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

    # The nodes were created inside name scopes, so their names carry the
    # "Input/" and "Output/" prefixes.
    x = graph.get_tensor_by_name("Input/x:0")
    prediction = graph.get_tensor_by_name("Output/predictions:0")

    # The frozen graph holds only constants, so no variable initialization
    # is needed.
    with tf.Session(graph=graph) as sess:
        pre = sess.run(prediction, feed_dict={x: x_data})

    plt.figure(figsize=(6, 6))
    plt.plot(x_data, pre, label="预测结果")
    plt.grid()
    plt.xlabel("x轴", fontproperties=font)
    plt.ylabel("y轴", fontproperties=font)
    plt.scatter(x_data, y_data, s=10, c="r", marker="*", label="实际值")
    plt.legend(prop=font)
    # Save (creating the output directory if needed) and show the image.
    os.makedirs("./images", exist_ok=True)
    plt.savefig("./images/pb_load.png", format="png")
    plt.show()
2 结果
2.1 训练结果
权重与损失值输出(最后三行依次为直接打印 prediction、loss 两个张量对象以及 train_step 返回值 None 的结果,并非数值)。
Weights_1 :[[-1.7107134 0.5941573 0.37450954 0.53004044 0.3793792 0.9144222
-1.5825071 -0.6608934 -0.96931577 0.5307749 ]]
weights_2 :[[ 0.31370506]
[-1.4543793 ]
[-1.9223864 ]
[ 0.14437917]
[-1.1137098 ]
[ 0.05373428]
[ 0.6884544 ]
[-0.01735083]
[ 0.03221066]
[ 1.2156694 ]]
Loss :0.004902012180536985
Tensor("Output/add:0", shape=(?, 1), dtype=float32)
Tensor("Loss/Mean:0", shape=(), dtype=float32)
None
2.2 载入模型验证结果
3 Java调用pb模型
3.1 pom.xml
<!-- TensorFlow Java binding: provides org.tensorflow.Graph, Session and Tensor
     used by AIController to run the frozen nn_model.pb.
     NOTE(review): this runtime presumably must be able to read the GraphDef
     produced by the Python TensorFlow version used for training; confirm. -->
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.5.0</version>
</dependency>
<!-- Apache Commons IO: provides IOUtils.toByteArray, used to read the model bytes. -->
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.6</version>
</dependency>
3.2 控制层
package com.sb.controller;
import org.apache.commons.io.IOUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.util.ResourceUtils;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.File;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * REST controller that serves predictions from the frozen TensorFlow graph
 * {@code model/nn_model.pb} found on the class path.
 */
@CrossOrigin(origins="*", maxAge=3600)
@RestController
@RequestMapping("/api/ai")
public class AIController{
    static Logger logger = LoggerFactory.getLogger(AIController.class);

    /** Class-path location of the frozen model (what "classpath:model/nn_model.pb" resolves to). */
    private static final String MODEL_RESOURCE = "model/nn_model.pb";

    /**
     * Runs the model on a single input value.
     *
     * @param datas request body; its "input" entry is parsed as a float
     * @return the model's prediction, or {@code input} unchanged if importing or
     *         running the graph fails (the failure is logged)
     * @throws Exception if "input" is absent or not a number, or if the model
     *         resource is not on the class path
     */
    @RequestMapping(value="pre", method=RequestMethod.POST)
    public Float predictionTest(@RequestBody Map datas) throws Exception{
        float input = Float.parseFloat(datas.get("input").toString());
        // The "Input/x" placeholder has shape [None, 1], so feed a 1x1 batch.
        float[][] x = {{input}};
        logger.info("input:{}", input);
        // Open the model through the class loader.  ResourceUtils.getFile() cannot
        // turn a resource packaged inside a jar into a java.io.File.
        java.io.InputStream modelStream = AIController.class.getClassLoader().getResourceAsStream(MODEL_RESOURCE);
        if(modelStream == null){
            throw new java.io.FileNotFoundException("class path resource [" + MODEL_RESOURCE + "] not found");
        }
        // Graph, Session and every Tensor own native memory and must be closed.
        // try-with-resources releases them, and the stream, even on failure.
        try(java.io.InputStream in = modelStream; Graph graph = new Graph()){
            graph.importGraphDef(IOUtils.toByteArray(in));
            // Python equivalent: pre = sess.run(pre, feed_dict={x: x_data})
            try(Session session = new Session(graph);
                Tensor inputTensor = Tensor.create(x);
                Tensor prediction = session.runner()
                        .feed("Input/x:0", inputTensor)
                        .fetch("Output/predictions:0").run().get(0)){
                // The output has shape [1, 1].  floatValue() works only on scalars
                // (it throws "Tensor is not a scalar"), so copy into a 1x1 array.
                float[][] preOutput = (float[][])prediction.copyTo(new float[1][1]);
                return preOutput[0][0];
            }
        }catch(Exception e){
            // Best-effort fallback kept from the original: report and echo the input.
            logger.error("Prediction failed, returning the input value instead", e);
        }
        return input;
    }
}
4 总结
- *.pb模型文件具有语言独立性,可脱离训练代码独立运行;它采用开放的 Protocol Buffers 序列化格式,可被任何支持 protobuf 的语言解析(如本文中的 Java)。
- *.pb模型文件中的变量已被转换为常量(Const 节点),即训练得到的变量值直接固化存储在图中。
- *.pb模型文件使用过程中不会重新“学习”,即模型参数不变,保证了模型的稳定性。
- *.pb模型文件实现了模型的瘦身:冻结时只保留计算指定输出节点所需的部分(优化器、初始化等训练相关节点被剔除),因此模型尺寸较小,适合部署到移动端等场景。
基础阅读:
[1](一)Tensorflow搭建神经网络
[2](二)Tensorflow神经网络保存模型(持久化)
[3](三)Tensorflow神经网络之模型载入及迁移学习
[参考文献]
[1]https://tensorflow.google.cn/versions/r1.12/api_docs/python/tf/gfile/FastGFile
[2]https://blog.csdn.net/fu6543210/article/details/80343345
[3]https://blog.csdn.net/wshzd/article/details/88840792