1、sess.run()
# Create the session and run the graph
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(iterator.initializer)
用sess.run()有两种情况:
1.想要获取某个变量的时候:
2.执行某种操作的时候,这个操作不是一个变量,没有值。比如上图为了初始化全部变量。
2、讲一下DataSet的iterator
参考:https://blog.csdn.net/briblue/article/details/80962728
注意这段代码前面是还有这么一句的
iterator = train_dataset.make_initializable_iterator()
2、summary.merge_all()合并默认图形中的所有汇总.
merged_summaries是一个节点,必须先传入session.run()运行才能获得真正的汇总!
summary.FileWriter():将汇总结果写入事件(event file)
# Merge all the summary and write
summary_op = tf.compat.v1.summary.merge_all()
train_filewriter = tf.compat.v1.summary.FileWriter('train/', sess.graph)
saver=tf.compat.v1.train.Saver(max_to_keep=1) #只保留最后一代模型,如果想保留全部,那就把max_to_keep=0
3、truepredictNum += np.sum(predictValue == testValue)计算预测正确的数量
accuracy1 = truepredictNum / 5000.0 #正确率
while (True):
try:
lossValue, lr, _ = sess.run([loss, learning_rate, opt_op]) #这里如果改成lossValue, lr = sess.run([loss, learning_rate])
if step % 100 == 0:
print("step %i: Learning_rate: %f Loss: %f" % (step, lr, lossValue))
if step % 1000 == 0:
saver.save(sess, 'model/my-model', global_step=step)
truepredictNum = 0
sess.run([testiterator.initializer, validiterator.initializer])
accuracy1 = 0.0
accuracy2 = 0.0
while (True):
try:
#在验证数据集上预测
predictValue, testValue = sess.run([validresult, validrecord_labels])
truepredictNum += np.sum(predictValue == testValue)
except tf.errors.OutOfRangeError:
print("valid correct num: %i" % (truepredictNum))
accuracy1 = truepredictNum / 5000.0
break
truepredictNum = 0
while (True):
try:
#在测试数据集上预测
predictValue, testValue = sess.run([testresult, testrecord_labels])
truepredictNum += np.sum(predictValue == testValue)
except tf.errors.OutOfRangeError:
print("test correct num: %i" % (truepredictNum))
accuracy2 = truepredictNum / 10000.0
break
summary = sess.run(summary_op, feed_dict={valid_accuracy: accuracy1, test_accuracy: accuracy2})
train_filewriter.add_summary(summary, step)
step += 1
except tf.errors.OutOfRangeError:
break
4、add_summary()将训练过程数据保存在filewriter指定的文件中
回头用tensorboard画图
valid_accuracy = tf.placeholder(tf.float32)
test_accuracy = tf.placeholder(tf.float32)
summary = sess.run(summary_op, feed_dict={valid_accuracy: accuracy1, test_accuracy: accuracy2})
train_filewriter.add_summary(summary, step)