使用tf.metrics.xx
[主要参考](https://zhuanlan.zhihu.com/p/42438077)
主要参考这个来理解tf.metrics返回的两个op都是什么意思,以及什么时候使用sess.run(tf.local_variables_initializer())
来进行batch的计算和整个数据集评估函数的计算。
返回的op
最简单粗暴来说,都用update_op 就好了。因为第二个执行后一个作用是更新变量,另外会同时返回一个结果,对于tf.metrics.accuracy,就是更新变量后实时计算的accuracy。
以accuracy为例,返回的两个op分别为:
- accuracy:用来计算accuracy的值的op
- update_op:用来更新截止上一次
sess.run(tf.local_variables_initializer())
后的总样本数total
和预测对的样本数count
,并返回最新accuracy的op。
使用tf.metrics进行train和valid的分别计算
文章里教了如何整体算和batch算,但是没写如何在同一个代码里实现train和valid的metrics计算。这里写一下:
其中tf_metrics为这里下载获得的用于多分类的评估函数
accuracy,accuracy_op = tf.metrics.accuracy(y_i,pre)
precision,precision_op = tf_metrics.precision(
y_i, pre, num_classes, pos_indices, average=average)
tf.summary.scalar('precision',precision_op) ##只要第二个op就可以了
tf.summary.scalar('accuracy',accuracy_op)
summary_op = tf.summary.merge_all()
with tf.Session() as sess:
sess.run(tf.initialize_all_variables(