# 深度学习工程化——在脱离源码的情况下,用检查点文件进行二次训练
# (Deep-learning engineering: resume training from a checkpoint file
# without the original graph-building source code.)
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
# ---- Generate synthetic training data: y = 2x + Gaussian noise ----
train_x = np.linspace(-1, 1, 100)
noise = np.random.randn(*train_x.shape) * 0.3
train_y = 2 * train_x + noise
# Scatter-plot the raw samples before training starts.
plt.plot(train_x, train_y, 'ro', label="Original data")
plt.legend()
plt.show()
'''定义生成loss值可视化的函数'''
plotdata={"batchsize":[],"loss":[]}
def moving_average(a,w=10):
if len(a)<w:
return a[:]
return [val if idx<w else sum(a[(idx-w):idx])/w for idx,val in enumerate(a)]
# Clear any graph left from a previous run so the meta-graph imported
# below is the only graph bound to the default session.
tf.reset_default_graph()
'''定义学习参数'''
# Training hyper-parameters: total epochs to train to, and how often
# (in epochs) to report the loss and save a checkpoint.
training_epochs=58
display_step=2
with tf.Session() as sess:
    # Locate the newest checkpoint under ./data/log2 and rebuild the graph
    # from its .meta file -- no model-definition source code is needed.
    savedir = os.path.join(os.getcwd(), "data", "log2")
    kpt = tf.train.latest_checkpoint(savedir)
    print("kpt:", kpt)
    new_saver = tf.train.import_meta_graph(kpt + '.meta')
    new_saver.restore(sess, kpt)

    # Recover the ops/tensors the original training script stored in
    # named collections when it built the graph.
    print(tf.get_collection('optimizer'))
    optimizer = tf.get_collection('optimizer')[0]
    X = tf.get_collection('X')[0]
    Y = tf.get_collection('Y')[0]
    cost = tf.get_collection('cost')[0]
    result = tf.get_collection('result')[0]
    global_step = tf.get_collection('global_step')[0]

    # Nodes restored -- resume training.  One epoch == len(train_x)
    # optimizer steps, and global_step counts optimizer steps, so
    # global_step / len(train_x) is the current epoch number.
    while global_step.eval() / len(train_x) < training_epochs:
        # BUGFIX: the original divided by len(train_y) here but by
        # len(train_x) in the loop condition; use one denominator
        # consistently (both arrays have the same length, so behavior
        # is unchanged, but the inconsistency was a latent bug).
        step = int(global_step.eval() / len(train_x))
        for (x, y) in zip(train_x, train_y):
            sess.run(optimizer, feed_dict={X: x, Y: y})

        # Periodically report the loss over the whole training set and
        # checkpoint the model.
        if step % display_step == 0:
            loss = sess.run(cost, feed_dict={X: train_x, Y: train_y})
            print("Epoch:", step + 1, "cost=", loss)
            # NOTE(review): loss is a float, so `loss == "NA"` is always
            # False and this guard always passes; kept for parity with
            # the original tutorial code.
            if not (loss == "NA"):
                plotdata["batchsize"].append(global_step.eval())
                plotdata["loss"].append(loss)
            new_saver.save(sess, os.path.join(savedir, "lineramodel.ckpt"), global_step)

    print("Finished!")
    # Final checkpoint and loss report after the target epoch is reached.
    new_saver.save(sess, os.path.join(savedir, "lineramodel.ckpt"), global_step)
    print("cost=", sess.run(cost, feed_dict={X: train_x, Y: train_y}))
# ---- Visualise the smoothed training-loss curve ----
plotdata["avgloss"] = moving_average(plotdata["loss"])
plt.figure(1)
plt.subplot(211)
# Dashed blue line: moving-average loss against the global step at
# which each loss value was recorded.
plt.plot(plotdata["batchsize"], plotdata["avgloss"], 'b--')
plt.xlabel('Minibatch number')
plt.ylabel('Loss')
plt.title('Minibatch run vs . Training loss')
plt.show()