pytorch where纵横不对称坑

图文匹配中, S ∈ [ 0 , 1 ] n × n S\in [0,1]^{n\times n} S[0,1]n×n 是一个相似度矩阵,即 S i j S_{ij} Sij 是第 i 幅图 I i I_i Ii 与第 j 条句子 T j T_j Tj 的相似度,而 ( I i , T i ) i = 1 n (I_i, T_i)_{i=1}^n (Ii,Ti)i=1n 是 ground-truth pair。检索文本(text retrieval)要求以 I i I_i Ii 为 quey 时 T i T_i Ti 排第几(即其 ranking);反过来检索图像(image retrieval)要求 T i T_i Ti 为 query 时 I i I_i Ii 的 ranking。

Text retrieval(逐行检索)可如此写:

import torch

# S[i][j] = sim(image_i, text_j)
S = torch.randperm(36).view(6, 6)
print("similarities:", S)

# text retrieval 求 ranking
asc_tx = S.argsort(1, descending=True)
print("asc_tx:", asc_tx)
sid = torch.arange(S.size(0)) # [n]
print("sample id:", sid)
rank_tx = torch.where(sid.unsqueeze(1) == asc_tx) # [n, 1]
print("rank_tx:", rank_tx)

# 两种判断是否为 top-1 的写法对拍,结果一致
print("top-1")
print(rank_tx[1] < 1)
print(sid == S.argmax(1))

结果:

similarities:
tensor([[ 5, 25, 28, 15, 29, 19],
        [ 3, 13, 21,  1,  0, 16],
        [ 9, 31, 12, 18, 32, 14],
        [17,  2, 26,  4, 10,  7],
        [ 8, 23, 11, 35, 34, 20],
        [24,  6, 27, 30, 22, 33]])

asc_tx:
tensor([[4, 2, 1, 5, 3, 0],
        [2, 5, 1, 0, 3, 4],
        [4, 1, 3, 5, 2, 0],
        [2, 0, 4, 5, 3, 1],
        [3, 4, 1, 5, 2, 0],
        [5, 3, 2, 0, 4, 1]])

sample id: tensor([0, 1, 2, 3, 4, 5])
rank_tx: (tensor([0, 1, 2, 3, 4, 5]), tensor([5, 2, 4, 4, 1, 0]))  # <- 第一个向量是行序号,升序,没问题

top-1
tensor([False, False, False, False, False,  True])  # 一致
tensor([False, False, False, False, False,  True])  # 一致

这种写法的思路是用 torch.argsort 按行排序,然后用 torch.where 求每一行序号等于 sample ID 的位置,即为 ranking。torch.where 返回的结果 rank_tx 是两个向量,第一个是行座标,第二个是列座标,由于 text retrieval 是逐行检索,所以列座标是 ranking。从结果看,这种写法没问题。

但当用同样思路写 image retrieval(逐列检索)时,出问题了:

import torch

# S[i][j] = sim(image_i, text_j)
S = torch.randperm(36).view(6, 6)
print("similarities:", S)

sid = torch.arange(S.size(0)) # [n]
# print("sample id:", sid)

# image retrieval
asc_im = S.argsort(0, descending=True) # 排序轴换成 0
print("asc_im:", asc_im)
rank_im = torch.where(sid.unsqueeze(0) == asc_im) # [1, n]
print("rank_im:", rank_im) # 不对劲

# 两种 top-1 写法对拍不过
print("top-1")
print(rank_im[0] < 1) # 取第一个个向量,即行位置
print(sid == S.argmax(0))

结果:

similarities:
tensor([[19, 16,  1, 15, 24, 28],
        [33, 21,  8,  3,  2, 34],
        [14, 25,  7, 32, 17,  0],
        [30,  6, 26, 11, 27,  4],
        [31, 20, 29, 22, 35, 23],
        [12, 13,  5, 18, 10,  9]])

asc_im:
tensor([[1, 2, 4, 2, 4, 1],
        [4, 1, 3, 4, 3, 0],
        [3, 4, 1, 5, 0, 4],
        [0, 0, 2, 0, 2, 5],
        [2, 5, 5, 3, 5, 3],
        [5, 3, 0, 1, 1, 2]])

rank_im: (tensor([0, 1, 3, 3, 3, 4]), tensor([4, 1, 0, 2, 5, 3]))  # <- 第二个向量是列序号,是乱序!

top-1
tensor([ True, False, False, False, False, False])  # 不一致
tensor([False, False, False, False,  True, False])  # 不一致

这种 image retrieval 的写法是按照前面 text retrieval 的写法对称改过来的:

  • argsort 排序轴 0 -> 1(按行 -> 按列)。这步没问题;
  • sid.unsqueeze(1) -> sid.unsqueeze(0),即换成求每列序号等于 sample ID 的位置。这步的结果就不对了,前面 rank_tx 的第一个向量是升序的行序号,而 rank_im 的第二个向量却是乱序的列序号!

这个现象就是题目所谓 torch.where 纵横不对称。从 rank_im 来看,torch.where 的策略是行主序搜索,即搜完一行再一行,保证其结果 rank_im 的第一个向量是非降的,rank_tx 也满足这点。

一个解决办法是:转置 argsort 结果,然后照抄逐行检索的写法:

import torch

# S[i][j] = sim(image_i, text_j)
S = torch.randperm(36).view(6, 6)
print("similarities:", S)

# image retrieval, corrected
asc_im = S.argsort(0, descending=True)
# rank_im = torch.where(sid.unsqueeze(0) == asc_im) # 出事写法
rank_im2 = torch.where(sid.unsqueeze(1) == asc_im.T) # 转置 argsort,按逐行检索写法来
print("rank_im2:", rank_im2)

print("top-1")
# print(rank_im[0] < 1)
print(rank_im2[1] < 1) # 还是用第二个向量
print(sid == S.argmax(0))

结果:

similarities:
tensor([[16, 17, 11, 10, 23, 33],
        [13, 15, 27, 34,  7, 24],
        [26, 29, 20,  6, 18, 31],
        [ 0, 32, 14, 12, 25, 35],
        [ 1,  2,  4,  9, 19, 22],
        [28, 30,  3,  5,  8, 21]])

rank_im2: (tensor([0, 1, 2, 3, 4, 5]), tensor([2, 4, 1, 1, 2, 5]))

top-1
tensor([False, False, False, False, False, False])  # 一致
tensor([False, False, False, False, False, False])  # 一致
  • 3
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值