tensorflow学习笔记11——开始run

1、sess.run()

# Create the session and run the graph
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(iterator.initializer)

用sess.run()有两种情况:
1.想要获取某个变量的时候:
2.执行某种操作的时候,这个操作不是一个变量,没有值。比如上图为了初始化全部变量。

2、讲一下DataSet的iterator
参考:https://blog.csdn.net/briblue/article/details/80962728

注意这段代码前面是还有这么一句的

iterator = train_dataset.make_initializable_iterator()

2、summary.merge_all()合并默认图形中的所有汇总.
merged_summaries是一个节点,必须先传入session.run()运行才能获得真正的汇总!

summary.FileWriter():将汇总结果写入事件(event file)

# Merge all the summary and write
summary_op = tf.compat.v1.summary.merge_all()
train_filewriter = tf.compat.v1.summary.FileWriter('train/', sess.graph)
saver=tf.compat.v1.train.Saver(max_to_keep=1)   #只保留最后一代模型,如果想保留全部,那就把max_to_keep=0

3、truepredictNum += np.sum(predictValue == testValue)计算预测正确的数量
accuracy1 = truepredictNum / 5000.0 #正确率

while (True):
    try:

        lossValue, lr, _ = sess.run([loss, learning_rate, opt_op])   #这里如果改成lossValue, lr = sess.run([loss, learning_rate])

        if step % 100 == 0:
            print("step %i: Learning_rate: %f Loss: %f" % (step, lr, lossValue))

        if step % 1000 == 0:
            saver.save(sess, 'model/my-model', global_step=step)
            truepredictNum = 0
            sess.run([testiterator.initializer, validiterator.initializer])
            accuracy1 = 0.0
            accuracy2 = 0.0

            while (True):
                try:
                    #在验证数据集上预测
                    predictValue, testValue = sess.run([validresult, validrecord_labels])
                    truepredictNum += np.sum(predictValue == testValue)
                except tf.errors.OutOfRangeError:
                    print("valid correct num: %i" % (truepredictNum))

                    accuracy1 = truepredictNum / 5000.0
                    break

            truepredictNum = 0

            while (True):
                try:
                    #在测试数据集上预测
                    predictValue, testValue = sess.run([testresult, testrecord_labels])
                    truepredictNum += np.sum(predictValue == testValue)
                except tf.errors.OutOfRangeError:
                    print("test correct num: %i" % (truepredictNum))

                    accuracy2 = truepredictNum / 10000.0
                    break

            summary = sess.run(summary_op, feed_dict={valid_accuracy: accuracy1, test_accuracy: accuracy2})
            train_filewriter.add_summary(summary, step)
        step += 1

    except tf.errors.OutOfRangeError:
        break

4、add_summary()将训练过程数据保存在filewriter指定的文件中
回头用tensorboard画图

valid_accuracy = tf.placeholder(tf.float32)
test_accuracy = tf.placeholder(tf.float32)
summary = sess.run(summary_op, feed_dict={valid_accuracy: accuracy1, test_accuracy: accuracy2})
train_filewriter.add_summary(summary, step)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值