Tensorflow的使用——模型的保存与加载

模型的存储和加载

训练好一个神经网络后,我们希望能够将其应用在预测数据上。那么,如何把模型存储起来呢?同时,对于一个已经存储起来的模型,在将其应用在预测数据上时又如何加载呢?

Tensorflow的API提供了以下两种方式来存储和加载模型

(1)生成检查点文件,扩展为一般为.ckpt,通过在tf.train.Saver对象上调用Saver.save()生成。它包含权重和其他在程序中定义的变量,不包含图结构。如果需要在另一个程序中使用,需要重新创建图形结构,并告诉Tensorflow如何处理这些权重。

(2)生成图协议文件,这是一个二进制文件,扩展名一般为.pb,用tf.train.write_graph()保存,只包含图形结构,不包含权重,然后使用tf.import_graph_def()来加载图形。

1.模型的存储与加载

模型存储主要是建立一个tf.train.Saver()来保存变量,并且指定保存的位置,一般模型的扩展名为.ckpt。

存储模型

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os

# 1.加载数据
mnist = input_data.read_data_sets('./data/MNIST_data/', one_hot=True)
tr_X, tr_Y, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels

# 2.设置占位符
X = tf.placeholder('float', [None, 784])
y = tf.placeholder('float', [None, 10])

# 3.定义权重函数
def init_weights(shape):
    return tf.Variable(tf.random_normal(shape=shape, stddev=0.01))

# 4.初始化权重参数
w_h = init_weights([784, 625])
w_h2 = init_weights([625, 625])
w_o = init_weights([625, 10])

# 5.定义网络模型
def model(X, w_h, w_h2, w_o, p_keep_input, p_keep_hidden):
    # *1.第一个全连接层
    X = tf.nn.dropout(X, p_keep_input)
    h = tf.nn.relu(tf.matmul(X, w_h))
    
    h = tf.nn.dropout(h, p_keep_hidden)
    # *2.第二个全连接层
    h2 = tf.nn.relu(tf.matmul(h, w_h2))
    h2 = tf.nn.dropout(h2, p_keep_hidden)
    
    return tf.matmul(h2, w_o)

# 6.生成网络模型,得到预测值
p_keep_input  = tf.placeholder('float')
p_keep_hidden  = tf.placeholder('float')
py_x = model(X, w_h, w_h2, w_o, p_keep_input, p_keep_hidden)

# 7.定义损失函数
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.arg_max(py_x, 1)

ckpt_dir = './data/ckpt_dir'
if not os.path.exists(ckpt_dir):    # 判断文件夹是否存在
    os.makedirs(ckpt_dir) # 创建新的文件夹

saver = tf.train.Saver()   # 创建节点保存器
# non_storable_variable = tf.Variable(777) # 可有可无
global_step = tf.Variable(0, name='global_step', trainable=False)
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())   # 初始化之前声明的变量
    start = global_step.eval()    # 获取global_step中的值,也就是初始值0.
    print('Start from:', start)
    for i in range(start, 100):
        for start, end in zip(range(0, len(tr_X), 128), range(128, len(tr_X)+1, 128)):
            sess.run(train_op, feed_dict={X:tr_X[start:end], y:tr_Y[start:end], p_keep_input:0.8, p_keep_hidden:0.5})
        global_step.assign(i).eval()
        saver.save(sess, ckpt_dir+'/model.ckpt', global_step=global_step)

 

模型的加载

# 模型的加载
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt and ckpt.model_checkpoint_path:
        print(ckpt.model_checkpoint_path)
        saver.restore(sess, ckpt.model_checkpoint_path)

2.图的存储与加载

当仅保存图模型时,才将图写入二进制协议文件中。

# 图的储存
# 当且保存图模型的时候,才将图写入二进制文件中
v = tf.Variable(0, name='chaucer')
sess = tf.Session()
tf.train.write_graph(sess.graph_def, './data/tmp/tfmodel','train_mnist.pbtxt')

图的读取,当读取时,从协议文件中都取出来。

# 图的读取
# 当读取时,从协议文件中都取出来

with tf.Session() as sess:
    with open('./data/tmp/tfmodel/train_mnist.pbtxt', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _sess.graph.as_default()
        tf.import_graph_def(graph_def, name='tfgraph')

参考:

Tensorflow技术解析与实战 李嘉璇

 

 

 

 

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值