Tensorflow:训练模型保存与恢复

 

首先明确一点,tensorflow保存的是什么?

模型保存后产生四个文件,分别是:

|--models
|    |--checkpoint
|    |--.meta
|    |--.data
|    |--.index

其中.meta保存的是图的结构,checkpoint文件是个文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表,.data和.index保存的是变量值。

即模型保存的是图的结构和变量值。


一 实例

以下是使用tensorflow实现简单的线性模型:

#生成样本数据
x = np.random.randn(10000,1)
y = 0.03*x+0.8

#定义模型参数
Weights = tf.Variable(tf.random_normal([1],seed=1,stddev=1),name='Weights')
bias = tf.Variable(tf.random_normal([1],seed=1,stddev=1),name='bias')


xx = tf.placeholder(tf.float32,shape=(None,1),name='xx')
yy = tf.placeholder(tf.float32,shape=(None,1),name='yy')

#线性模型
y_predict = tf.add(Weights*xx,bias,name='preds')

#损失函数
loss = tf.reduce_mean(tf.square(yy-y_predict))

#优化方法
optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

#批训练模型
batchsize = 20
samplesize = 100
start = 0
end = 0
with tf.Session() as sess:
    init_var = tf.global_variables_initializer()
    sess.run(init_var)
    print('before training, variable is %s,%s'%(sess.run(Weights),sess.run(bias)))
    for i in (range(5000)):
#         start = (i*batchsize)%100
        if end == samplesize:
            start = 0
        end = np.minimum(start+batchsize,samplesize)
#         try:
#             end = np.min(start+batchsize,samplesize)
#         except:
#             print(end)
        sess.run(optimizer,feed_dict={xx:x[start:end],yy:y[start:end]})
        if (i+1)%1000 == 0:
            print(sess.run(loss,feed_dict={xx:x[start:end],yy:y[start:end]}))
        start += batchsize
    print('after training, variable is %s,%s'%(sess.run(Weights),sess.run(bias)))

 

二 模型保存

通过以下程序可实现保存:

saver = tf.train.Saver()
saver.save(session,dir[,global_step])

save中第一个参数是session,第二个参数是模型保存的位置,第三个参数申明模型每迭代多少步保存一次。

保存一中的模型,并设置每1000步保存一次:

#生成样本数据
x = np.random.randn(10000,1)
y = 0.03*x+0.8

#定义模型参数
Weights = tf.Variable(tf.random_normal([1],seed=1,stddev=1),name='Weights')
bias = tf.Variable(tf.random_normal([1],seed=1,stddev=1),name='bias')


xx = tf.placeholder(tf.float32,shape=(None,1),name='xx')
yy = tf.placeholder(tf.float32,shape=(None,1),name='yy')

#线性模型
y_predict = tf.add(Weights*xx,bias,name='preds')

#损失函数
loss = tf.reduce_mean(tf.square(yy-y_predict))

#优化方法
optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

#批训练模型
batchsize = 20
samplesize = 100
start = 0
end = 0
with tf.Session() as sess:
    init_var = tf.global_variables_initializer()
    sess.run(init_var)
    saver = tf.train.Saver()

    print('before training, variable is %s,%s'%(sess.run(Weights),sess.run(bias)))
    for i in (range(5000)):
#         start = (i*batchsize)%100
        if end == samplesize:
            start = 0
        end = np.minimum(start+batchsize,samplesize)
#         try:
#             end = np.min(start+batchsize,samplesize)
#         except:
#             print(end)
        sess.run(optimizer,feed_dict={xx:x[start:end],yy:y[start:end]})

        #实现每1000步保存一次模型
        if (i+1)%1000 == 0:
            saver.save(sess,'models\ckp',1000)
            print(sess.run(loss,feed_dict={xx:x[start:end],yy:y[start:end]}))
        start += batchsize
    print('after training, variable is %s,%s'%(sess.run(Weights),sess.run(bias)))

以下代码实现了每1000步保存一次模型

if (i+1)%1000 == 0:
    saver.save(sess,'models\ckp',1000)

之所以这样做,是为了防止意外情况下(比如训练时突然断电)下次训练需要从头开始训练。

保存的目录结构如下

|--models
|    |--checkpoint
|    |--ckp-1000.meta
|    |--ckp-1000.data-00000-of-00001
|    |--ckp-1000.index

 

三 模型恢复

首先加载保存的meta文件

saver = tf.train.import_meta_graph(file_name)

恢复参数,依赖于session,dir表示模型保存的目录路径,此时所有张量的值都在session中

saver.restore(session,tf.train.latest_checkpoint(dir))

获取恢复的参数,varname表示恢复的参数名,因此建议所有的参数都加上name属性

graph = sess.graph #sess所打开的图,所有的结构都在这个图上
graph.get_tensor_by_name(varname)

以下给出回归模型的恢复,并利用训练好的模型进行预测:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('models\ckp-1000.meta')
    saver.restore(sess,tf.train.latest_checkpoint('models'))
    graph = tf.get_default_graph()

    #恢复传入值
    xx = graph.get_tensor_by_name('xx:0')


    #计算利用训练好的模型参数计算预测值
    preds = graph.get_tensor_by_name('preds:0')
    print('predict values:%s' % sess.run(preds,feed_dict={xx:x}))

 

 

 

 

  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值