Springboot2模块系列:tensorflow(载入pb模型)

1 神经网络结构

1.1 保存*.pb模型

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
import numpy as np
import os

import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
from mpl_toolkits.mplot3d import Axes3D
# Font used for the Chinese axis labels and legend drawn in load_pb_model().
# NOTE(review): hard-coded Ubuntu path to a CJK font (arphic/ukai.ttc);
# adjust this path on systems where that font is not installed.
font = FontProperties(fname='/usr/share/fonts/truetype/arphic/ukai.ttc')

# ---- Output locations for the trained model and TensorBoard logs ----------
MODEL_SAVE_PATH = "./models"
MODEL_NAME_meta = "nn_model.ckpt"
MODEL_NAME_pb = "nn_model.pb"
LOG_DIR = "./logs/NNmergelog"

# ---- Synthetic training data: y = x^2 - 0.5*x + N(0, 0.05) noise ----------
# 250 evenly spaced samples in [-1, 1] as a (250, 1) float32 column, which is
# the [None, 1] layout the network's input placeholder expects.
x_data = np.linspace(-1, 1, 250, dtype=np.float32).reshape(-1, 1)
noise = np.random.normal(0, 0.05, size=x_data.shape).astype(np.float32)
y_data = x_data * x_data - 0.5 * x_data + noise

# ---- Layer sizes: 1 input -> 10 hidden units -> 1 output --------------------
input_size_1, output_size_1 = 1, 10
input_size_2, output_size_2 = 10, 1
def saved_model_pb():
    """Train the toy regression network and export it as a frozen ``*.pb`` file.

    Builds a 1 -> 10 (ReLU) -> 1 fully connected network and trains it with
    gradient descent (learning rate 0.1) for 301 steps on the module-level
    ``x_data`` / ``y_data``. Along the way it writes TensorBoard summaries to
    ``LOG_DIR`` and prints the weights and loss every 50 steps. After
    training, every variable is folded into a constant and the resulting
    graph is written to ``MODEL_SAVE_PATH/MODEL_NAME_pb``.

    The exported graph keeps the nodes ``Input/x``, ``Input/y`` and
    ``Output/predictions``. Loaders feed ``Input/x:0`` and fetch
    ``Output/predictions:0``.
    """
    # Build in a private graph. Otherwise a second call in the same process
    # would add nodes to the default graph under uniquified scopes
    # ("Input_1/x", ...), and the fixed export names would no longer match.
    graph = tf.Graph()
    with graph.as_default():
        # Input layer: batches of scalar samples, shape [None, 1].
        with tf.name_scope("Input"):
            xs = tf.placeholder(tf.float32, [None, 1], name='x')
            ys = tf.placeholder(tf.float32, [None, 1], name='y')

        # Hidden layer: ReLU(x @ W1 + b1).
        with tf.name_scope("Layer1"):
            weights_1 = tf.Variable(tf.random_normal([input_size_1, output_size_1]), name='weights_1')
            biases_1 = tf.Variable(tf.zeros([1, output_size_1]), name='biases_1')
            layer_1 = tf.nn.relu(tf.matmul(xs, weights_1) + biases_1)
            tf.summary.histogram('weights_1', weights_1)
            tf.summary.histogram('biases_1', biases_1)
            tf.summary.histogram('layer_1', layer_1)

        # Output layer: linear. The node is named "predictions" so that it
        # can be fetched by name after export.
        with tf.name_scope("Output"):
            weights_2 = tf.Variable(tf.random_normal([input_size_2, output_size_2]), name='weights_2')
            biases_2 = tf.Variable(tf.zeros([1, output_size_2]), name='biases_2')
            outputs_2 = tf.matmul(layer_1, weights_2)
            prediction = tf.add(outputs_2, biases_2, name="predictions")
            tf.summary.histogram('weights_2', weights_2)
            tf.summary.histogram('biases_2', biases_2)
            tf.summary.histogram('prediction', prediction)

        # Loss: squared error summed over the output column, averaged over the batch.
        with tf.name_scope("Loss"):
            loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), reduction_indices=[1]))
            tf.summary.scalar('loss', loss)
            tf.summary.histogram('loss', loss)

        with tf.name_scope("Train_Step"):
            train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

        merged = tf.summary.merge_all()
        init_op = tf.global_variables_initializer()

    feed = {xs: x_data, ys: y_data}
    with tf.Session(graph=graph) as sess:
        summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
        try:
            sess.run(init_op)
            for step in range(301):
                summary, _, loss_value = sess.run([merged, train_step, loss], feed_dict=feed)
                summary_writer.add_summary(summary, step)
                if step % 50 == 0:
                    # Report training progress every 50 steps.
                    w1, w2 = sess.run([weights_1, weights_2])
                    print("Weights_1 :{}".format(w1))
                    print("weights_2 :{}".format(w2))
                    print("Loss :{}".format(loss_value))
        finally:
            # Flush and close the event file even if training fails midway.
            summary_writer.close()

        # Freeze once, after the last training step, so the exported
        # constants hold the final weights. Only the sub-graph needed for the
        # listed nodes is kept; training/summary ops are pruned.
        constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ["Input/x", "Input/y", "Output/predictions"])

    # The output directory may not exist on a fresh checkout.
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
    with tf.gfile.FastGFile(os.path.join(MODEL_SAVE_PATH, MODEL_NAME_pb), mode="wb") as f:
        f.write(constant_graph.SerializeToString())

1.2 载入*.pb模型

def load_pb_model():
    """Load the frozen ``*.pb`` model, run it on fresh data and plot the fit.

    Reads ``MODEL_SAVE_PATH/MODEL_NAME_pb``, the file written by
    ``saved_model_pb``. It feeds 250 points in [-1, 1] into ``Input/x:0`` and
    fetches ``Output/predictions:0``. The prediction curve is drawn against a
    freshly sampled noisy target ``x^2 - 0.5*x + noise``. The figure is saved
    to ``./images/pb_load.png`` and then shown.
    """
    # Evaluation data. New noise is drawn, so the scatter differs from the
    # training set.
    x_eval = np.linspace(-1, 1, 250, dtype=np.float32)[:, np.newaxis]
    noise = np.random.normal(0, 0.05, x_eval.shape).astype(np.float32)
    y_eval = np.square(x_eval) - 0.5 * x_eval + noise

    with gfile.FastGFile(os.path.join(MODEL_SAVE_PATH, MODEL_NAME_pb), "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import into a private graph. Importing with name='' into the shared
    # default graph again would collide with nodes from an earlier call.
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')

    # Node names carry the name_scope prefixes used when the model was built.
    x = graph.get_tensor_by_name("Input/x:0")
    predictions = graph.get_tensor_by_name("Output/predictions:0")
    with tf.Session(graph=graph) as sess:
        pre = sess.run(predictions, feed_dict={x: x_eval})

    plt.figure(figsize=(6, 6))
    plt.plot(x_eval, pre, label="预测结果")
    plt.grid()
    plt.xlabel("x轴", fontproperties=font)
    plt.ylabel("y轴", fontproperties=font)
    plt.scatter(x_eval, y_eval, s=10, c="r", marker="*", label="实际值")
    plt.legend(prop=font)
    # savefig does not create missing directories.
    os.makedirs("./images", exist_ok=True)
    plt.savefig("./images/pb_load.png", format="png")
    plt.show()

2 结果

2.1 训练结果

训练过程中打印的权重与损失值(最后一次输出):

Weights_1 :[[-1.7107134   0.5941573   0.37450954  0.53004044  0.3793792   0.9144222
  -1.5825071  -0.6608934  -0.96931577  0.5307749 ]]
weights_2 :[[ 0.31370506]
 [-1.4543793 ]
 [-1.9223864 ]
 [ 0.14437917]
 [-1.1137098 ]
 [ 0.05373428]
 [ 0.6884544 ]
 [-0.01735083]
 [ 0.03221066]
 [ 1.2156694 ]]
Loss :0.004902012180536985
Tensor("Output/add:0", shape=(?, 1), dtype=float32)
Tensor("Loss/Mean:0", shape=(), dtype=float32)
None

2.2 载入模型验证结果

(图片缺失:此处原为 load_pb_model() 生成的 ./images/pb_load.png,即模型预测曲线与带噪声实际值散点的对比图)

图2.1 预测值与理论值

3 java调用pb模型

3.1 pom.xml

<!-- TensorFlow Java bindings: load and run the frozen *.pb graph.
     NOTE: the Java runtime must support every op in the exported graph, so
     keep this version compatible with the Python TensorFlow used for export. -->
<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>tensorflow</artifactId>
  <version>1.5.0</version>
</dependency>
<!-- Apache Commons IO: IOUtils.toByteArray reads the model bytes. -->
<dependency>
  <groupId>commons-io</groupId>
  <artifactId>commons-io</artifactId>
  <version>2.6</version>
</dependency>

3.2 控制层

package com.sb.controller;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Map;

import org.apache.commons.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.ResourceUtils;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RestController;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

@CrossOrigin(origins="*", maxAge=3600)
@RestController
@RequestMapping("/api/ai")
public class AIController{
    static Logger logger = LoggerFactory.getLogger(AIController.class);

    /** Classpath location of the frozen model exported by the Python training script. */
    private static final String MODEL_RESOURCE = "model/nn_model.pb";

    /**
     * Runs the frozen TensorFlow model on a single scalar input.
     *
     * <p>Python equivalent: {@code pre = sess.run(pre, feed_dict={x: x_data})}.
     *
     * @param datas JSON request body; its {@code "input"} value must parse as a float,
     *              e.g. {@code {"input": 0.5}}
     * @return the value of {@code Output/predictions:0} for that input
     * @throws IllegalArgumentException if the body has no {@code "input"} entry
     * @throws Exception if the model cannot be read or evaluated
     */
    @RequestMapping(value="pre", method=RequestMethod.POST)
    public Float predictionTest(@RequestBody Map<String, Object> datas) throws Exception{
        Object rawInput = datas.get("input");
        if (rawInput == null) {
            throw new IllegalArgumentException("request body must contain \"input\"");
        }
        float input = Float.parseFloat(rawInput.toString());
        // Input/x is a [None, 1] placeholder: feed a batch of one row.
        float[][] x = {{input}};
        logger.info("input:{}", input);

        byte[] graphBytes = readModel();
        try (Graph graph = new Graph()) {
            graph.importGraphDef(graphBytes);
            // Tensors hold native memory and must be closed, as must the session.
            try (Session session = new Session(graph);
                 Tensor<?> inputTensor = Tensor.create(x);
                 Tensor<?> prediction = session.runner()
                         .feed("Input/x:0", inputTensor)
                         .fetch("Output/predictions:0")
                         .run()
                         .get(0)) {
                // The output is a [1, 1] matrix, not a scalar, so floatValue() throws
                // "Tensor is not a scalar"; copy it into a float[1][1] instead.
                float[][] preOutput = (float[][]) prediction.copyTo(new float[1][1]);
                return preOutput[0][0];
            }
        }
    }

    /**
     * Reads the model bytes from the classpath through a stream, which works both
     * from an exploded build and from a packaged jar (ResourceUtils.getFile does not).
     */
    private static byte[] readModel() throws IOException {
        try (InputStream in = AIController.class.getClassLoader().getResourceAsStream(MODEL_RESOURCE)) {
            if (in == null) {
                throw new IOException("model not found on classpath: " + MODEL_RESOURCE);
            }
            return IOUtils.toByteArray(in);
        }
    }
}

4 总结

  • *.pb模型文件采用 Protocol Buffers 序列化,与编程语言无关:只要有 TensorFlow 运行时(如 Java、C++ 绑定),任何语言都可以解析并执行该模型。
  • *.pb模型文件中的变量已被转换为常量(Const 节点),即训练得到的参数值直接固化存储在计算图中。
  • *.pb模型文件使用过程中不会重新“学习”,即模型参数不变,保证了模型的稳定性。
  • 导出时只保留计算指定输出节点所需的子图,优化器、梯度、summary 等训练相关节点会被裁剪,因此模型更精简,适合部署到服务端或移动端。

基础阅读:
[1](一)Tensorflow搭建神经网络
[2](二)Tensorflow神经网络保存模型(持久化)
[3](三)Tensorflow神经网络之模型载入及迁移学习


[参考文献]
[1]https://tensorflow.google.cn/versions/r1.12/api_docs/python/tf/gfile/FastGFile
[2]https://blog.csdn.net/fu6543210/article/details/80343345
[3]https://blog.csdn.net/wshzd/article/details/88840792


  • 1
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

天然玩家

坚持才能做到极致

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值