最近在实现DRRM, 计算直方图,官方的api, tf.histogram_fixed_width不支持批量的直方图计算
在网上搜了一些方法,https://stackoverflow.com/questions/41764199/row-wise-histogram/,但是运算速度都比较慢,我自己想了一种,根据连续值转整型的方法,来进行批量计算直方图。
- tensor: 是输入的tensor
- nbins: 目标分桶个数
- row_num, col_num:输入的行列长度
直接输入 tensor 下面的函数就可以返回结果为 (row_num, nbins)的批量直方图结果.
输入的tensor会乘以nbins, 然后取整,其实就是对输入进行了分桶了。
其原理大致是生成一个 [nbins, row_num, nbins] 的mask矩阵,针对第0维度,每个位置的值,就是这个位置的序号。比如[0, row_num, nbins] 的所有值为0, [1, row_num, nbins] 的所有值为1,… etc
分桶后的值与mask分别进行==对比,然后加和,就可以得到每个桶内的值了,即直方图
def histogram_v3_nomask(tensor, nbins, row_num