图像检索绘制PR曲线的两种方法(按需求)或者说生成指标的两种方法

根据不同哈希码的长度来判断最好的情况

每个哈希码长度下所有样本的平均精度和召回率

def pr_curve(qB, rB, query_label, retrieval_label):
	"只接受二进制码组成的0和1 或者 -1和1组成的二进制码,请把所有标签进行onehot编码"
    qB[qB==-1] = 0 ; rB[rB==-1] = 0 # 根据你哈希码长度来改变数值将
    num_query = qB.shape[0]
    num_bit = qB.shape[1]
    P = torch.zeros(num_query, num_bit+1)
    R = torch.zeros(num_query, num_bit+1)
    for i in range(num_query):
        gnd = (query_label[i].unsqueeze(0).mm(retrieval_label.t()) > 0).float().squeeze()
        print('gnd size is ' + str(gnd.shape))
        '''对于一个(num,4)与(4,num)大小的矩阵相乘最后得到是(num,num)大小的矩阵,他会展示标签之间是否相同'''
        # 这里的groudtruth标签大小是(retrieval_label,)
        tsum = torch.sum(gnd) # 一共有多少个相同的标签(事实上的,没有固定检索哈希码的长度)
        if tsum == 0:
            continue # 如果没有相似的也没必要继续进行了
        hamm = calc_hamming_dist(qB[i, :], rB) # 汉明距离
        print('hamming distance is '+str(hamm.shape))
        tmp = (hamm <= torch.arange(0, num_bit+1).reshape(-1, 1).float().to(hamm.device)).float()
        print('tmp size is '+str(tmp.shape))
        '''
        计算的汉明距离与特定哈希码长度做比较,如果超出了当前哈希码长度就是False
        解释:取一个查询集的哈希码,计算与所有检索集的汉明距离,然后使用tmp变量来存储特定哈希码长度下(有最大阈值的限制)的表现 
            tmp产生的列表是(num_bits + 1 , retrieval_label) 要么是True要么是False
        '''
        # 在指定哈希码长度下能检测到的最大阈值——哈希码的长度是多少 ( 也就是在特定哈希码长度的限制之下,能检测到的最多的检索集个数
        total = tmp.sum(dim=-1) # 降维1维列表,对应的每个哈希码长度对应的检索正确的整数
        print('total size is'+str(total.shape))
        total = total + (total == 0).float() * 0.1 # 将为0的元素变成0.1
        t = tmp * gnd # t是能被当前哈希码长度准确检测的到的实际个数 shape(num_bits,) —— 正确预测
        print('t size is '+str(t.shape))
        count = t.sum(dim=-1) # 所有可以被当前某个特定哈希码长度有效检索的图像的总数
        '''
        强调一下,这里的变量所代表的意义 : 
        1 , count 由 t 产生 代表当前哈希码长度下,能正确匹配到的正样本数量
        2 , t是能被当前哈希码长度正确检测的实际个数 , t 是比gnd小的 ,因为它更代表能被当前哈希码检索的这个前提条件还能识别到的groud truth标签
        3 , total 对应每个特定哈希码长度下所能匹配到的最大匹配数量,包含了正样本和负样本
        4 , tsum 本来就能就是正确的总数(正确答案)也就是所有的groud truth标签
        '''
        p = count / total # percision(精度) = (所有特定哈希码长度下可以有效检索的个数)/ (每个特定哈希码下正确的个数) —— 正确预测 / 总样本数 —— 模型的整体性能
        print('p shape is ' + str(p.shape))

        r = count / tsum # 正确预测 / 所有的正确样本
        print('r shape is ' + str(r.shape))
        P[i] = p # 第一个查询标签对应的Persion的指标
        R[i] = r # 以此类推
    print(f'P size is {str(P.shape)}')
    print(f'R size is {str(R.shape)}')
    mask = (P > 0).float().sum(dim=0) # 只考虑含有正样本的查询集
    mask = mask + (mask == 0).float() * 0.1
    P = P.sum(dim=0) / mask
    R = R.sum(dim=0) / mask
    '''
    num_bit表示不同的哈希码长度,范围为0到num_bit
    对于每个query,计算其在不同哈希码长度下的precision(P)和recall(R)
    计算完成后,利用mask只考虑非零值,求出所有query在每个哈希码长度下的平均P和平均R
    '''
	# 以下是可视化
    plt.plot(R, P, linestyle="-", marker='D', color='blue',label = 'DSH')
    plt.text(0.5, -0.1, '(a) PR curve @ 16bits', ha='center', va='center', fontsize=16, fontweight='bold', transform=plt.gca().transAxes)
    plt.grid(True)
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    # 调整图像的大小
    fig = plt.gcf()
    fig.set_size_inches(9,9)
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.legend()  # 加图例
    plt.show()

    return P, R

上述函数是用来计算每个哈希码长度下所有样本的平均精度和召回率

每个检索集样本在不同哈希码长度下的平均精度和召回率

这种方法的好处在与鲁棒性很强,可以接受任何类型的哈希码

def plot_pr_curve(query_binary, retrieval_binary, query_label, retrieval_label):
    """
    绘制检索评价的精度-召回率曲线(PR曲线)
    Args:
        query_binary (numpy.ndarray): 一个大小为 (num_query, num_bit) 的 numpy 数组,
            包含查询图像的二进制哈希码。
        retrieval_binary (numpy.ndarray): 一个大小为 (num_retrieval, num_bit) 的 numpy 数组,
            包含检索图像的二进制哈希码。
        query_label (numpy.ndarray): 一个大小为 (num_query,) 的 numpy 数组,
            包含查询图像的真实标签。
        retrieval_label (numpy.ndarray): 一个大小为 (num_retrieval,) 的 numpy 数组,
            包含检索图像的真实标签。

    Returns:
        None
    """
    # Convert the labels to int32 type to avoid indexing errors
    ## 将标签转换为 int32 类型,以避免索引错误
    query_label = query_label.astype(np.int32)
    retrieval_label = retrieval_label.astype(np.int32)

    # Calculate the Hamming distances between query and retrieval binary codes
    # 计算查询图像和检索图像之间的汉明距离
    hamming_dist = np.count_nonzero(query_binary[:, np.newaxis, :]
                                    != retrieval_binary[np.newaxis, :, :], axis=2)
    '''
    query_binary 和 retrieval_binary 是二进制向量,形状分别是(m, n) 和(p, n),n 为向量维度。
    query_binary[:, np.newaxis, :] 将查询向量形状改为(m, 1, n),插入第2维一个1维度。
    retrieval_binary[np.newaxis, :, :] 将索引向量形状改为(1, p, n),插入第1维一个1维度。
    然后使用 != 计算这两个三维矩阵对应位置不相等的元素个数,结果形状是(m, p)。
    np.count_nonzero() 计算Axis=2 不为0 的元素个数,即每个二维的(m, p) 对应位置的汉明距离。
    '''
    print(hamming_dist)
    # Sort the retrieval samples by ascending order of Hamming distance
    # 将检索图像按汉明距离排序
    idx = np.argsort(hamming_dist, axis=1) # 按照索引进行排序,汉明距离最小的索引派在前面

    # Initialize the precision-recall arrays
    # 初始化精度和召回率数组
    num_query = query_binary.shape[0]
    num_retrieval = retrieval_binary.shape[0]
    precision = np.zeros((num_query, num_retrieval))
    recall = np.zeros((num_query, num_retrieval))

    # Compute the precision-recall values for each query sample
    # 计算每个查询图像的精度和召回率
    for i in range(num_query):
        # Compute the ground-truth labels for the retrieval samples
        # 计算检索图像的真实标签
        gnd = (query_label[i] == retrieval_label[idx[i]])
        # Compute the cumulative sums of true positives and false positives
        # 计算真正例和假正例的累积和
        tp_cumsum = np.cumsum(gnd)
        fp_cumsum = np.cumsum(~gnd) # ~表示补码 , 取反
        # Compute the precision and recall values
        # 计算精度和召回率
        precision[i] = tp_cumsum / (tp_cumsum + fp_cumsum) # 正样本占总样本的比例
        recall[i] = tp_cumsum / np.count_nonzero(gnd) # 检测到的正样本占所有正样本的比例
    # Compute the mean precision and recall values over all queries
    # 计算所有查询图像的平均精度和召回率
    mean_precision = np.mean(precision, axis=0)
    mean_recall = np.mean(recall, axis=0)

    # Plot the precision-recall curve
    plt.plot(mean_recall, mean_precision, 'b-')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.show()

区别

代码上的区别

这两段代码计算PR曲线的方式存在以下主要区别:

计算汉明距离的方式不同:
第一段代码直接计算qB和rB的汉明距离。

第二段代码先将qB和rB升维,然后比较对应位置不相等的元素个数,得到汉明距离。

排序检索样本的方式不同:
第一段代码没有给出排序的细节。

第二段代码明确地按照汉明距离升序排列检索样本。

统计正负样本的方式不同:
第一段代码通过gnd矩阵记录每个查询与所有检索样本是否匹配。

第二段代码比较query和retrieval label是否相同,来判断是否匹配。

计算precision和recall的方式不同:
第一段代码直接计算整个过程得到的count和total,得到precision。

第二段代码逐步累加true positive和false positive,从而逐步计算precision和recall。

求平均precision和recall的方式相同:
都通过mask只考虑非零precision的查询,从而计算所有查询的平均precision和recall。
综上,两段代码的计算流程存在较大差异,主要体现在:

汉明距离的计算
检索样本的排序
匹配查询和检索的方式
每步计算precision和recall的方法
所以说两段代码计算PR曲线算法上存在较大差异,但最终目的相同,都得到平均precision和recall。

应用场景的区别

首段代码:

计算的是每个哈希码长度下,所有查询图像的平均精度和召回率。
产生的数据点个数是:哈希码长度 + 1。
适用于场景:评估不同哈希码长度下的 overall 性能。
第二段代码:

计算的是每个查询图像在不同哈希码长度下的精度和召回率曲线。
产生的数据点个数是:检索集样本个数。
然后取平均,得到每个哈希码长度下的平均精度和召回率。
适用于场景:分析不同查询图像在不同哈希码长度下的表现。
应用场景:

如果只关注不同哈希码长度下的 average PR 曲线,首段代码足够。

如果想分析不同查询图像在不同长度下的表现差异,则第二段代码更有用。可以发现不同类别图像适合的哈希码长度是否存在差异。

总的来说:

首段代码计算的是每个哈希码长度下所有查询图像的平均 PR 曲线,适用于评估不同哈希码长度的整体效果。

第二段代码计算的是每个查询图像在不同哈希码长度下的 PR 曲线,再求平均,适用于分析不同查询图像的哈希长度敏感度。

两种代码对应的数据点的意义

这两段代码绘制的PR曲线其实对应的数据点并不完全一样。

具体来说:

第一段代码计算的是:每个哈希码长度下,所有查询图像的平均精度和召回率。
这意味着:

对应每一个哈希码长度,它都产生1个数据点(P,R),表示在这个哈希码长度下,所有查询图像的平均精度和召回率。

最后共产生 num_bit + 1 个数据点,对应0~num_bit这num_bit+1个哈希码长度。

所以这num_bit+1个数据点反映的是:随着哈希码长度变化,整体模型性能的变化。

而第二段代码计算的是:

每个查询图像在不同哈希码长度下的精度和召回率曲线。
这意味着:

对应每一个查询图像和每一个哈希码长度,它都产生1个数据点(P,R)。

最终,会为每个查询图像产生相同的num_bit + 1个数据点。

然后再取平均,得到每个哈希码长度下所有查询图像的平均精度和召回率。

所以这个num_bit + 1个数据点反映的是:不同查询图像在各自哈希长度下的表现。

所以两者产生的数据点的意义是不同的:

第一种方式更关注模型整体性能随哈希码长度变化的趋势。
第二种方式更关注每个查询图像的表现,最后通过平均得到整体趋势。
总的来说,我的理解是两段代码绘制的PR曲线其实对应的数据点并不完全一样。他们所反映的信息不同,分别侧重点不同。

如果要绘制固定哈希码长度下检索集的表现,我们应该选用第二种代码。

如果需要绘制固定几个哈希码长度(16、32、64、128)下的 PR曲线。

第二种方式为每个查询图像计算精度和召回率,再取平均。这能更清晰地反映固定哈希码长度下的表现。

第二种方式对应的一个 PR 点,是基于检索集样本个数的。所以对于固定的哈希码长度,它能更充分利用数据,更准确地描绘 PR曲线。

而第一种方式对应的一个 PR 点,是基于哈希码长度的。对于固定的哈希码长度,它只产生1个 PR 点。所以信息量相对较少。

所以总的来说:

如果你要绘制固定几个哈希码长度下的 PR 曲线,第二种方式计算每个查询图像的 PR 曲线更合适。
因 为它能基于检索集样本整体情况,更准确地描绘固定哈希码长度下的 PR 曲线。
而第一种方式只产生1个 PR 点,信息量相对较少。
希望我的分析有助于你选择正片代码。

  • 要点总结:

你的需求是绘制固定几个哈希码长度下的 PR 曲线。
第二种方式绘制每个查询图像的 PR 曲线再取平均,对固定哈希码长度更合适。
因为它基于整个检索集样本,信息量更多,能描绘更准确的 PR曲线。
相比第一种方式只产生1个 PR 点,信息量相对较少。

第一种代码的适合的使用场景

第一种代码的可能适用场景主要有两个:

  • 评估不同哈希码长度下的整体效果
    通过计算每个哈希码长度下所有查询样本的平均精度和召回率,可以评估随着哈希码长度增加,模型整体性能的变化情况。

  • 比较不同模型在不同哈希码长度下的效果
    在保持其他参数不变的情况下,比如使用相同的数据集、相同的查询集和检索集,可以比较不同模型(如DSH、ITQ)在不同哈希码长度下的平均PR曲线,从而评估它们的整体效果。

  • 首段代码的特点是:

计算的是每个哈希码长度下所有查询样本的平均PR值
产生的数据点个数是哈希码长度+1
反映的是随着哈希码长度变化,整体模型性能的变化
所以主要适用于需要评估模型整体效果如何变化的场景。

  • 而第二段代码的特点是:

计算的是每个查询样本在不同哈希码长度下的PR值,再取平均
数据点个数是检索样本个数
反映的是不同查询样本在不同哈希码长度下的表现
更适用于分析不同查询样本在不同哈希码长度下的差异。

  • 11
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Dou_Huanmin

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值