深度学习工程化--在脱离源码的情况下,用检查点文件进行二次训练

深度学习工程化–在脱离源码的情况下,用检查点文件进行二次训练

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

'''生成模拟数据'''
train_x=np.linspace(-1,1,100)
train_y=2*train_x+np.random.randn(*train_x.shape)*0.3
plt.plot(train_x,train_y,'ro',label="Original data")
plt.legend()
plt.show()

'''定义生成loss值可视化的函数'''
plotdata={"batchsize":[],"loss":[]}
def moving_average(a,w=10):
    if len(a)<w:
        return a[:]
    return [val if idx<w else sum(a[(idx-w):idx])/w for idx,val in enumerate(a)]
tf.reset_default_graph()
'''定义学习参数'''
training_epochs=58
display_step=2
with tf.Session() as sess:
    savedir=os.path.join(os.getcwd(),"data","log2")
    kpt=tf.train.latest_checkpoint(savedir)
    print("kpt:",kpt)
    new_saver=tf.train.import_meta_graph(kpt+'.meta')
    new_saver.restore(sess,kpt)

    print(tf.get_collection('optimizer'))
    optimizer=tf.get_collection('optimizer')[0]
    X=tf.get_collection('X')[0]
    Y=tf.get_collection('Y')[0]
    cost=tf.get_collection('cost')[0]
    result=tf.get_collection('result')[0]
    # print(tf.get_collection('global_step'))
    global_step=tf.get_collection('global_step')[0]
    # global_step = tf.Variable(0, name="global_step", trainable=False)
    # tf.add_to_collection('gloabal_step', global_step)
    '''节点恢复完成,可以继续训练'''
    while global_step.eval()/len(train_x)<training_epochs:
        step=int(global_step.eval()/len(train_y))
        for (x,y) in zip(train_x,train_y):
            sess.run(optimizer,feed_dict={X:x,Y:y})
        '''显示训练中的详细信息'''
        if step % display_step==0:
            loss=sess.run(cost,feed_dict={X:train_x,Y:train_y})
            print("Epoch:",step+1,"cost=",loss)
            if not (loss=="NA"):
                plotdata["batchsize"].append(global_step.eval())
                plotdata["loss"].append(loss)
            new_saver.save(sess,os.path.join(savedir,"lineramodel.ckpt"),global_step)

    print("Finished!")
    new_saver.save(sess,os.path.join(savedir,"lineramodel.ckpt"),global_step)
    print("cost=",sess.run(cost,feed_dict={X:train_x,Y:train_y}))

    plotdata["avgloss"]=moving_average(plotdata["loss"])
    plt.figure(1)
    plt.subplot(211)
    plt.plot(plotdata["batchsize"],plotdata["avgloss"],'b--')
    plt.xlabel('Minibatch number')
    plt.ylabel('Loss')
    plt.title('Minibatch run vs . Training loss')
    plt.show()

在这里插入图片描述
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值