自己在使用tensorflow的时候,想要保存下训练好的的模型,以供下次使用。到网上看了很多教程,大多数使用的是tf.train.Saver(),这种方法还是太麻烦,没法直接像其他框架一样的保存成一个黑盒,你只要给输入就行。后来找了很多的博客,总算是找到了一种比较简单的方法,就是使用tf.saved_model.builder。
接下来以一个CNN训练mnist手写数字识别的例子介绍这种方法
模型保存
训练及保存的代码如下:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
batch_size = 100
n_batch = mnist.train.num_examples // batch_size
def weight_variable(shape):
return tf.Variable(tf.truncated_normal(shape,stddev=0.1))
def bias_vairable(shape):
return tf.Variable(tf.constant(0.1, shape=shape))
def conv2d(x,W):
return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME')
def max_pool_2x2(x):
return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
x = tf.placeholder(tf.float32,[None,784], name="input_x")
y = tf.placeholder(tf.float32,[None,10], name="input_y")
keep_prob = tf.placeholder(tf.float32)
x_image = tf.reshape(x,[-1,28,28,1