实例引入
实例1:演示with session的使用
import tensorflow as tf
a = tf.constant(3)
b = tf.constant(4)
with tf.Session() as sess:
print("相加:%i"%sess.run(a+b))
print("相乘:%i"%sess.run(a*b))
结果:
相加:7
相乘:12
实例2:演示注入机制
说明:使用注入机制,将具体的实参注入到相应的placeholder中。feed只在调用它的方法内有效,方法结束后feed就会消失。
import tensorflow as tf
a = tf.placeholder(tf.int16)
b = tf.placeholder(tf.int16)
add = tf.add(a,b)
mul = tf.multiply(a,b)
with tf.Session() as sess:
print("相加:%i" % sess.run(add,feed_dict={a:3,b:4}))
print("相乘:%i" % sess.run(mul,feed_dict={a:3,b:4}))
附:使用注入机制获取节点
print(sess.run([add,mul],feed_dict={a:3,b:4}))
如何保存和载入模型
1.保存模型
首先需要建立一个saver,然后在session中保存saver的save即可将模型保存起来。
#前面是各种构建模型的graph操作(sigmoid等等)
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())#先对模型初始化
#然后将数据丢入模型进行训练
#....
#训练完之后,使用saver.save进行保存
saver.save(sess,"save_path/file_name")
2.载入模型
在session中通过调用saver的restore()函数,会从指定的路径找到模型文件,并且覆盖到相关参数中。
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess,"save_path/file_name")
检查点
概述:保存模型并不限于训练之后,在训练中也需要保存,因为TensorFlow训练模型时难免会出现中断的情况。我们自然希望能够将辛苦得到的中间参数保留下来,否则下次又要重新开始。这种在训练中保存模型,习惯上称之为保存检查点。
创建saver时多添加一个参数,max_to_keep=1,表明最多只保存一个检查点文件。
saver = tf.train.Saver(max_to_keep=1)
在保存时使用了如下代码传入了迭代次数。
saver.save = (sess,savedir+"filename.cpkt",global_step=epoch)
在载入时,同样也要指定迭代次数
saver.restore(sess2,savedir+"filename.cpkt-"+str(load_epoch))
另外还可以用另外一个方法快速获取到检查点文件
kpt = tf.train.latest_checkpoint(savedir)
if kpt!=None:
saver.restore(sess,kpt)
更简洁地保存检查点
概述:使用tf.train.MonitoredTrainingSession函数。该函数可以直接实现保存及载入检查点模型的文件。该方法是通过按照训练时间来保存的。通过指定save_checkpoint_secs参数的具体秒数,来设置每训练多久保存一次检查点。
import tensorflow as tf
tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step,1)
with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpoints',
save_checkpoint_secs=2) as sess:
print(sess.run([global_step]))
#启用死循环,当sess不结束时就不停止
while not sess.should_stop():
i = sess.run(step)
print(i)
注意:
1.通过设置save_checkpoint_secs,更适合使用大型数据集来训练复杂模型的情况。
2.使用该方法,必须定义global_step变量,否则会报错。
TensorBoard可视化
API介绍
函数 | 说明 |
---|---|
tf.summary.scalar(tag,values,collections=None,name=None) | 标量数据汇总,输出protobuf |
tf.summary.histogram(tag,values,collections=None,name=None) | 记录变量var的直方图,输出带直方图的汇总的protobuf |
tf.summary.image(tag,values,max_image=3,collections=None,name=None) | 图像数据汇总,输出protobuf |
tf.summary.merge(inputs,collections=None,name=None) | 合并所有的汇总日志 |
tf.summary.FileWriter | 合并所有的汇总日志 |
Class SummaryWriter:add_summary(),add_graph() | 将protobuf写入文件的类 |
实例:线性回归的tensorboard可视化
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
train_X = np.linspace(-1,1,100)
train_Y = 2*train_X + np.random.randn(*train_X.shape)*0.3
#创建模型
X = tf.placeholder("float")
Y = tf.placeholder("float")
#模型参数
W = tf.Variable(tf.random_normal([1]),name='weight')
b = tf.Variable(tf.zeros([1]),name='bias')
#前向结构
z = tf.multiply(X,W)+b
#将预测值以直方图形式显示
tf.summary.histogram('z',z)
#反向优化
cost = tf.reduce_mean(tf.square(Y-z))
#将损失以标量形式显示
tf.summary.scalar('loss_function',cost)
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
#训练模型
init = tf.global_variables_initializer()
#定义参数
training_epochs = 20
display_step = 2
#创建saver
saver = tf.train.Saver()
savedir = "tensorflow的学习/"
#启动session
with tf.Session() as sess:
sess.run(init)
plotdata={"batchsize":[],"loss":[]}
#合并所有summary
merged_summary_op = tf.summary.merge_all()
#创建summary_writer
summary_writer = tf.summary.FileWriter('log/mnist_with_summaries',sess.graph)
#向模型中输入数据
for epoch in range(training_epochs):
for(x,y) in zip(train_X,train_Y):
sess.run(optimizer,feed_dict={X:x,Y:y})
#显示训练中的详细信息
if epoch % display_step == 0:
loss = sess.run(cost,feed_dict={X:train_X,Y:train_Y})
print("Epoch:",epoch+1,"cost=",loss,"W=",sess.run(W),"b=",sess.run(b))
if not(loss=="NA"):
plotdata["batchsize"].append(epoch)
plotdata["loss"].append(loss)
#生成summary
summary_str = sess.run(merged_summary_op,feed_dict={X:x,Y:y});
summary_writer.add_summary(summary_str,epoch);
print("Finished!")
saver.save(sess,savedir+"model.cpkt")
print("cost=",sess.run(cost,feed_dict={X:train_X,Y:train_Y}),
"W=",sess.run(W),"b=",sess.run(b))
#模型可视化
# plt.plot(train_X,train_Y,'ro',label='Original data')
# plt.plot(train_X,sess.run(W)*train_X+sess.run(b),label='Fittedline')
# plt.legend()
# plt.show()
with tf.Session() as sess2:
#使用模型
saver.restore(sess2,savedir+"model.cpkt")
print("x=0.2,z=",sess2.run(z,feed_dict={X:0.2}))
说明:生成文件后,输入cmd,来到summary日志的上级路径下,输入如下命令:
tensorboard --logdir 存放文件的路径(xxxxxx/mnist_with_summaries)