深度学习工程化--在线性回归模型中,向检查点文件中添加指定节点

深度学习工程化–在线性回归模型中,向检查点文件中添加指定节点

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

'''生成模拟数据'''
train_x=np.linspace(-1,1,100)
train_y=2*train_x+np.random.randn(*train_x.shape)*0.3
plt.plot(train_x,train_y,'ro',label="Original data")
plt.legend()
plt.show()

tf.reset_default_graph()
'''构建网络模型'''
X=tf.placeholder("float")
Y=tf.placeholder("float")
W=tf.Variable(tf.random_normal([1]),name='weight')
b=tf.Variable(tf.zeros([1]),name="bias")
z=tf.multiply(X,W)+b
global_step=tf.Variable(0,name="global_step",trainable=False)
'''反向优化'''
cost=tf.reduce_mean(tf.square(Y-z))
learning_rate=0.01
optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(cost,global_step)
'''定义初始化所有变量'''
init=tf.global_variables_initializer()
training_epochs=20
display_step=2
savedir=os.path.join(os.getcwd(),"data","log2")
'''保存一个检查点文件'''
saver=tf.train.Saver(tf.global_variables(),max_to_keep=1)
tf.add_to_collection('optimizer',optimizer)
tf.add_to_collection('X',X)
tf.add_to_collection('Y',Y)
tf.add_to_collection('cost',cost)
tf.add_to_collection('result',z)
tf.add_to_collection('gloabal_step',global_step)
plotdata={"batchsize":[],"loss":[]}
def moving_average(a,w=10):
    if len(a)<w:
        return a[:]
    return [val if idx<w else sum(a[(idx-w):idx])/w for idx,val in enumerate(a)]
with tf.Session() as sess:
    sess.run(init)
    kpt=tf.train.latest_checkpoint(savedir)
    if kpt!=None:
        saver.restore(sess,kpt)

    while global_step.eval()/len(train_x)<training_epochs:
        step=int(global_step.eval()/len(train_x))
        for (x,y) in zip(train_x,train_y):
            sess.run(optimizer,feed_dict={X:x,Y:y})

        '''显示训练中的详细信息'''
        if step % display_step==0:
            loss=sess.run(cost,feed_dict={X:train_x,Y:train_y})
            print("Epoch:",step+1,"cost=",loss,"W=",sess.run(W),"b=",sess.run(b))
            if not (loss=="NA"):
                plotdata["batchsize"].append(global_step.eval())
                plotdata["loss"].append(loss)
            saver.save(sess,savedir+"linermodel.cpkt",global_step)
        print("Finished!")
        saver.save(sess,savedir+"linermodel.cpkt",global_step)
        print("cost=",sess.run(cost,feed_dict={X:train_x,Y:train_y}),"W=",sess.run(W),"b=",sess.run(b))

        '''显示模型'''
        plt.plot(train_x,train_y,'ro',label="Original data")
        plt.plot(train_x,sess.run(W)*train_x+sess.run(b),label="Fittedline")
        plt.legend()
        plt.show()

        plotdata["avgloss"]=moving_average(plotdata["loss"])
        print(plotdata["batchsize"])
        print(plotdata["avgloss"])

        plt.figure(1)
        plt.subplot(211)
        plt.plot(plotdata["batchsize"],plotdata["avgloss"],'b--')
        plt.xlabel('Minibatch number')
        plt.ylabel("Loss")
        plt.title("Minibatch run vs . Training loss")
        plt.show()

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值