# 深度学习工程化——在线性回归模型中，向检查点文件中添加指定节点
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
# Synthetic training data: 100 evenly spaced x values on [-1, 1] and
# y = 2x corrupted by Gaussian noise scaled by 0.3.
train_x = np.linspace(start=-1, stop=1, num=100)
train_y = train_x * 2 + 0.3 * np.random.randn(*train_x.shape)

# Show the raw samples before training.
plt.plot(train_x, train_y, 'ro', label="Original data")
plt.legend()
plt.show()
# Start from an empty default graph so re-running this script (e.g. in a
# notebook session) does not accumulate duplicate nodes.
tf.reset_default_graph()
'''构建网络模型'''
# Forward model: z = X * W + b.
# X and Y are fed a single sample at a time during training and the whole
# train_x / train_y arrays during evaluation, so no fixed shape is declared.
X=tf.placeholder("float")
Y=tf.placeholder("float")
W=tf.Variable(tf.random_normal([1]),name='weight')
b=tf.Variable(tf.zeros([1]),name="bias")
z=tf.multiply(X,W)+b
# Non-trainable step counter. It is passed to minimize() below, which
# increments it on every optimizer run. Because it is a global variable, the
# Saver includes it in checkpoints, which lets training resume where it stopped.
global_step=tf.Variable(0,name="global_step",trainable=False)
'''反向优化'''
# Backward pass: mean squared error minimized with plain gradient descent.
cost=tf.reduce_mean(tf.square(Y-z))
learning_rate=0.01
optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(cost,global_step)
'''定义初始化所有变量'''
# Initializer for all variables created so far. It must be defined after
# every variable in the graph has been created.
init=tf.global_variables_initializer()
training_epochs=20
# Report progress (and checkpoint) every `display_step` epochs.
display_step=2
# Checkpoint directory: <current working directory>/data/log2
savedir=os.path.join(os.getcwd(),"data","log2")
'''保存一个检查点文件'''
# Saver over all global variables, keeping only the most recent checkpoint.
saver=tf.train.Saver(tf.global_variables(),max_to_keep=1)
# Register the key nodes in named graph collections. Saver.save() exports the
# meta-graph (collections included), so a separate loader can fetch these
# nodes with tf.get_collection(<name>) instead of rebuilding the model.
tf.add_to_collection('optimizer',optimizer)
tf.add_to_collection('X',X)
tf.add_to_collection('Y',Y)
tf.add_to_collection('cost',cost)
tf.add_to_collection('result',z)
# NOTE(review): the key 'gloabal_step' is misspelled. It is kept unchanged
# because a loader presumably looks it up under this exact string; confirm
# before renaming.
tf.add_to_collection('gloabal_step',global_step)
# Training-curve history filled in by the training loop: the global step at
# each report and the corresponding full-dataset loss.
plotdata={"batchsize":[],"loss":[]}
def moving_average(a, w=10):
    """Smooth a sequence with a trailing window of width ``w``.

    If ``a`` has fewer than ``w`` items, a shallow copy (``a[:]``) is returned
    unchanged. Otherwise a new list is built where the first ``w`` entries are
    copied as-is, and each later entry at index ``i`` is replaced by the mean
    of the ``w`` entries *before* it, ``a[i - w:i]`` (the current value is
    excluded).

    Args:
        a: Sliceable sequence of numbers.
        w: Window width.

    Returns:
        The copied or smoothed sequence.
    """
    if len(a) < w:
        return a[:]
    smoothed = []
    for i, current in enumerate(a):
        if i < w:
            smoothed.append(current)
        else:
            window = a[(i - w):i]
            smoothed.append(sum(window) / w)
    return smoothed
# Train the model, resuming from the latest checkpoint in `savedir` if one
# exists. Checkpoint every `display_step` epochs, then plot the fitted line
# and the smoothed loss curve.
with tf.Session() as sess:
    sess.run(init)
    # Saver.save() fails when the checkpoint's parent directory is missing,
    # so create it up front.
    os.makedirs(savedir, exist_ok=True)
    # Checkpoint files must live *inside* savedir: latest_checkpoint(savedir)
    # only reads the "checkpoint" index file in that directory. The previous
    # `savedir + "linermodel.cpkt"` dropped the path separator and wrote
    # ".../data/log2linermodel.cpkt-N" beside it, so resuming never worked.
    ckpt_prefix = os.path.join(savedir, "linermodel.cpkt")
    kpt = tf.train.latest_checkpoint(savedir)
    if kpt is not None:
        saver.restore(sess, kpt)
    # global_step advances once per optimizer run, i.e. once per sample, so
    # global_step / len(train_x) is the number of completed epochs.
    while global_step.eval() / len(train_x) < training_epochs:
        step = int(global_step.eval() / len(train_x))
        # One epoch of per-sample gradient descent.
        for (x, y) in zip(train_x, train_y):
            sess.run(optimizer, feed_dict={X: x, Y: y})
        # Report progress, record the loss, and checkpoint every
        # `display_step` epochs.
        if step % display_step == 0:
            loss = sess.run(cost, feed_dict={X: train_x, Y: train_y})
            print("Epoch:", step + 1, "cost=", loss, "W=", sess.run(W), "b=", sess.run(b))
            # Skip diverged (NaN/inf) losses so they do not corrupt the curve.
            # The former `loss == "NA"` compared a float with a string and was
            # never true.
            if np.isfinite(loss):
                plotdata["batchsize"].append(global_step.eval())
                plotdata["loss"].append(loss)
            saver.save(sess, ckpt_prefix, global_step)
    print("Finished!")
    saver.save(sess, ckpt_prefix, global_step)
    print("cost=", sess.run(cost, feed_dict={X: train_x, Y: train_y}), "W=", sess.run(W), "b=", sess.run(b))

    # Plot the training samples against the learned line.
    plt.plot(train_x, train_y, 'ro', label="Original data")
    plt.plot(train_x, sess.run(W) * train_x + sess.run(b), label="Fittedline")
    plt.legend()
    plt.show()

    # Smooth the recorded losses and plot them against the global step.
    plotdata["avgloss"] = moving_average(plotdata["loss"])
    print(plotdata["batchsize"])
    print(plotdata["avgloss"])
    plt.figure(1)
    plt.subplot(211)
    plt.plot(plotdata["batchsize"], plotdata["avgloss"], 'b--')
    plt.xlabel('Minibatch number')
    plt.ylabel("Loss")
    plt.title("Minibatch run vs . Training loss")
    plt.show()