tensorflow:一个简单的python训练保存模型,java还原模型方法

总结一下这段时间学习使用tensorflow的一些经验。主要应用场景是,使用python语言训练一个简单的LR模型,并且将模型以savedModel格式保存模型,然后以python和java语言还原模型,预测结果。

python tensorflow版本和java client版本要对应,通过 tf.__version__ 查看python的TensorFlow版本,java client版本就是jar包版本。

(1)训练模型

import tensorflow as tf
import numpy as np

#生成训练数据
x = np.ndarray(dtype=np.float32, shape=[4, 2])
x[0] = [1,1]
x[1] = [1,2]
x[2] = [1,3]
x[3] = [2,4]
print('====================')
print(x)
print(x.shape)
print(x.dtype)

#创建placeHolder作为输入
x_inputs = tf.placeholder(tf.float32, shape=[None, 2])

#输出结果
y_true = tf.constant([[2], [4], [5], [9]], dtype=tf.float32)

#单层神经网络,搭建LR模型
linear_model = tf.layers.Dense(units=1)
y_pred = linear_model(x_inputs)

#构建session
sess = tf.Session()
#保存模型tensorbord可视化结构的writer
writer = tf.summary.FileWriter("/Users/yourName/pythonworkspace/tmp/log", sess.graph)
#初始化变量
init = tf.global_variables_initializer()
sess.run(init)
#构建损失函数
loss = tf.losses.mean_squared_error(labels=y_true, predictions=y_pred)

#梯度下降优化器
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)

#开始训练模型
print('================start=================')
for i in range(10000):
    _, loss_value = sess.run((train, loss), feed_dict={x_inputs:x})
    if i % 1000 == 0:
        print(loss_value)

#关闭可视化writer,可以通过tensorboard --logdir /Users/yourName/pythonworkspace/tmp/log加载可视化模型
writer.close()

#构建savedModel构建器
builder = tf.saved_model.builder.SavedModelBuilder("/Users/yourName/pythonworkspace/tmp/savedModel/lrmodel")
# x 为输入tensor, keep_prob为dropout的prob tensor
inputs = {'input': tf.saved_model.utils.build_tensor_info(x_inputs)}

# y 为最终需要的输出结果tensor
outputs = {'output': tf.saved_model.utils.build_tensor_info(y_pred)}

signature = tf.saved_model.signature_def_utils.build_signature_def(inputs, outputs, 'test_sig_name')
#保存模型
builder.add_meta_graph_and_variables(sess, ['test_saved_model'], {'test_signature':signature})
builder.save()

(2)python 加载模型

import tensorflow as tf
with tf.Session(graph=tf.Graph()) as sess:
  #加载模型
  meta_graph_def = tf.saved_model.loader.load(sess, ['test_saved_model'], "/Users/yourName/pythonworkspace/tmp/savedModel/lrmodel")
  #加载模型签名
  signature = meta_graph_def.signature_def
  print(signature)
  #从签名中获得张量名
  y_tensor_name = signature['test_signature'].outputs['output'].name
  x_tensor_name = signature['test_signature'].inputs['input'].name
  print(y_tensor_name)
  print(x_tensor_name)
  #还原张量
  y_pred = sess.graph.get_tensor_by_name(y_tensor_name)
  x_inputs = sess.graph.get_tensor_by_name(x_tensor_name)

  # 预测结果
  print(sess.run(y_pred, feed_dict={x_inputs:[[1,6]]}))

(3)java 加载模型
加载tensorflow依赖包

 <dependencies>
        <!-- https://mvnrepository.com/artifact/org.tensorflow/tensorflow -->
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow</artifactId>
            <version>1.8.0-rc0</version>
        </dependency>
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>proto</artifactId>
            <version>1.8.0-rc1</version>
        </dependency>
    </dependencies>

加载模型代码

import com.google.protobuf.InvalidProtocolBufferException;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;

import java.util.List;


public class Test {
    public static void main(String[] args) throws InvalidProtocolBufferException {

        /*加载模型 */
        SavedModelBundle savedModelBundle = SavedModelBundle.load("/Users/yourName/pythonworkspace/tmp/savedModel/lrmodel", "test_saved_model");
        /*构建预测张量*/
        float[][] matrix = new float[1][2];
        matrix[0][0] = 1;
        matrix[0][1] = 6;
        Tensor<Float> x = Tensor.create(matrix, Float.class);
        /*获取模型签名*/
        SignatureDef sig = MetaGraphDef.parseFrom(savedModelBundle.metaGraphDef()).getSignatureDefOrThrow("test_signature");
        String inputName = sig.getInputsMap().get("input").getName();
        System.out.println(inputName);
        String outputName = sig.getOutputsMap().get("output").getName();
        System.out.println(outputName);
        /*预测模型结果*/
        List<Tensor<?>> y = savedModelBundle.session().runner().feed(inputName, x).fetch(outputName).run();
        float [][] result = new float[1][1];
        System.out.println(y.get(0).dataType());
        System.out.println(y.get(0).copyTo(result));
        System.out.println(result[0][0]);

    }
}
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

WitsMakeMen

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值