基于tensorflow统计topK准确率

问题描述

简单介绍下思路:
我们有tf.nn.top_k可以直接用,能够按照概率由高到低返回前k个索引,注意,tf.nn.top_k实际上是返回两个参数的,第一个是前k大的值,第二个是前k大的值所在的索引,这里我们只需要后者。
我们得到了每个样本预测出的前k个类别,接下来我们想知道对每个样本来说,其真实类别是否存在于这k个预测类别之中,最后我们统计k个预测类别中存在有真实类别的样本数量,除以样本总数,得到topK准确率。
我们将以上转化为一个数值运算问题,对每个样本来说,若存在,输出1;不存在,输出0。然后计算由所有样本的0、1输出组成的array的平均值。

代码

所以计算top的时候,实际上是用tensorflow的API做了二次开发,代码如下:

def caculate_topK(indices, k):
	a = indices - tf.reshape(self.b_labels, (batch_size, 1))
	b = tf.equal(a, tf.zeros(shape=(batch_size, k), dtype=tf.int32))
	return tf.reduce_mean(tf.reduce_sum(tf.cast(b, tf.float32), axis=1), name='top_{}'.format(k))
_, self.top_5_indices = tf.nn.top_k(y, k=5, name='top_5_indices')
self.acc_top_5 = caculate_topK(self.top_5_indices, 5)

这里动用了tf.equal、tf.cast这俩不常用的方法。

思路解析

思路其实还是很有趣的,下面是具体思路:

  1. 前k个预测类别分别同真实类别做减法,如果类别相同,返回0;若类别不同,则返回非0的差值。传播至该batch中所有样本,得到张量a,其shape=(batch_size, k)。
  2. 接下来我们需要比较a中0的个数,tensor做条件判断很麻烦,所以我们基于布尔运算迂回一下,基于tf.equal,如果一个样本的a值是[1,0,-3,4,5],那我们利用tf.equal,使其同[0*5]对比,得到对应的b值为[False, True, False, False, False];
  3. 调用tf.cast将布尔矩阵转换为01矩阵,这个方法还是比较玄幻的
  4. 然后求和计算平均值,大功告成。

Good Luck!

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值