跨模态搜索——MAP计算过程

跨模态搜索——MAP计算过程

1.汉明距计算:

公式:

解释:不同的位数越多,汉明距越大

代码:

def calc_hamming_dist(B1, B2):
    q = B2.shape[1]   #哈希码位数
    if len(B1.shape) < 2:
        B1 = B1.unsqueeze(0)
    distH = 0.5 * (q - B1.mm(B2.t()))  #计算汉明码距离 公式4
    return distH

输入:

B1: [1, -1, 1, 1]

B2:[1, -1, 1, -1], [-1, -1, 1, -1], [-1, -1, 1, -1], [1, 1, -1, -1], [-1, 1, -1, -1], [1, 1, -1, 1]]

输出:[1., 2., 2., 3., 4., 2.]

2. MAP计算

假设原始相关向量为:[1,1,0,0,1,0](1表示与它相关,0表示与它不相关),那么一共三个相关的。按照预测的哈希进行汉明码排序后,为[0,0,1,1,1,0]。则MAP=(1/3+2/4+3/5)/3=0.4778。再例,原始相关向量为:[1,1,0,0,1,0]。按照预测的哈希进行汉明码排序后,为[1,0,1,1,0]。则MAP=(1/1+2/3+3/4)/3=0.9167。

如果k为2,原始相关向量为[1,1,0,0,1,0],重排后为[1,0,1,1,0],则MAP=(1/1+2/3)/2=0.8333。

def calc_map_k(qB, rB, query_label, retrieval_label, k=None):
    # qB:查询集  范围{-1,+1}
    # rB:检索集  范围{-1,+1}
    # query_label: 查询标签
    # retrieval_label: 检索标签
    num_query = query_label.shape[0]  #查询个数
    map = 0.
    if k is None:
        k = retrieval_label.shape[0]  #如果不指定k,k将是全部检索个数。对于flickr25k数据集,即18015
    for iter in range(num_query):
        #每个查询标签乘以检索标签的转置,只要有相同标签,该位置就是1
        gnd = (query_label[iter].unsqueeze(0).mm(retrieval_label.t()) > 0).type(torch.float).squeeze()
        tsum = torch.sum(gnd)   #真实相关的数据个数
        print("相关个数:",tsum)
        if tsum == 0:
            continue
        hamm = calc_hamming_dist(qB[iter, :], rB)
        _, ind = torch.sort(hamm) #ind :已排序的汉明距,在未排序中的位置
        ind.squeeze_()
        print("原始 gnd:",gnd)
        print("ind    :", ind)
        gnd = gnd[ind]  #按照预测的顺序重排
        print("重排后gnd:", gnd)
        total = min(k, int(tsum))  #取k和tsum的最小值,这句应该没啥用
        #如果有三个相关的,则count是[1,2,3]
        count = torch.arange(1, total + 1).type(torch.float).to(gnd.device)
        #取出重排后非0元素的位置
        tindex = torch.nonzero(gnd)[:total].squeeze().type(torch.float) + 1.0
        print("count:",count)
        print("tindex:",tindex)
        map += torch.mean(count / tindex)
        print("map:",map)
    map = map / num_query
    return map

输出:

相关个数:tensor(3.)
原始 gnd: tensor([1., 1., 0., 0., 1., 0.])
ind    : tensor([5, 3, 4, 0, 1, 2])
重排后gnd: tensor([0., 0., 1., 1., 1., 0.])
count: tensor([1., 2., 3.])
tindex: tensor([3., 4., 5.])
map: tensor(0.4778)

特点:相关性,只要有相同标签的就算。如果有n条相关的数据,你只需要把这n条全部找出来,这n条数据内部的顺序不考虑。

 

 

 

 

  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值