本文介绍的内容适用于 tf.contrib.metrics.streaming_true_positives 以及类似的 streaming_XXX 的 metrics。
调用格式
import tensorflow as tf
a = tf.Variable([0,1,1,0], tf.bool)
b = tf.Variable([0,1,0,1], tf.bool)
tp, tp_update = tf.contrib.metrics.streaming_true_positives(predictions = a, labels = b)
返回参数
参数 tp 很好理解,是预测值 predictions 和金标准 labels 的 True Positive 值。
关于第二个参数 tp_update,很多同学就要问了,为什么要第二个参数,第一个参数 tp 不就能够完成计算 True Positive 的工作了嘛?理解 tp_update 就需要从 streaming 这个单词来理解了。
函数机制
我的理解,streaming类的 metrics 函数在调用的时候,其实是先更新 tp_update(这样的shadow variable),此时 tp_update 记录了之前所有的混淆矩阵,而 tp 仅仅是计算当前 tp_update 中的 True Positive。(有没有很像一种 streaming 操作)
调用注意点(非常重要)
一般需要先进行一次 sess.run 更新 tp_update,再 sess.run 计算当前 tp 值。
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer()) # tp_update 是保存在 tf.local_variables()中
sess.run(tp_update)
print(sess.run(tp))
验证和理解函数机制
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer()) # tp_update 是保存在 tf.local_variables()中
sess.run(tp_update) # 第一次更新 tp_update
print(sess.run(tp))
sess.run(tp_update) # 第二次更新 tp_update
print(sess.run(tp))
我们更新两次,第二次的 tp_update 应该会保留第一次 tp_update 的结果,并且再此基础上再叠加一次。所以得到的结果则是:
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer()) # tp_update 是保存在 tf.local_variables()中
sess.run(tp_update)
print(sess.run(tp))
sess.run(tf.local_variables_initializer()) # 清空 tp_update 等价于没有计算 tp_update
print(sess.run(tp))
可以发现清空了 tp_update 或者没有计算 tp_update,直接计算 tp 值并不会成功。