torch.nn.functional.cosine_similarity计算相似度

# 目的是求a、b之间的余弦相似度。一般是[bs,h]形状。以[2,4]为例。


a = torch.rand(2, 4)
b = torch.rand(2, 4)
print('初始输入\n', a, '\n', b)

ab = torch.matmul(a, b.t())
print('ab相乘\n', ab)

aa = a*a
print('a*a\n',aa)
aa = aa.sum(1, keepdim=True)
print('求和aa')
aa = aa ** 0.5
print('开方aa\n', aa)

bb = b * b
print('b*b\n',bb)
bb = bb.sum(1, keepdim=True)
print('求和bb')
bb = bb ** 0.5
print('开方bb\n', bb)

   
c = ab/aa/bb
print('自己算的cosi相似度\n', c)

d = torch.nn.functional.cosine_similarity(a.unsqueeze(1), b.unsqueeze(0),dim=2)    
print('固有算法求cos相似度\n', d)    

这段代码实际上没有实现。你会发现只有对角线的位置计算正确,而非对角线位置错误。

a:[2,4]  
按行分块 结果可写为:
A1
A2

b:[2,4]按行分块
B1
B2

a*b.t()相当于是
A1B1 A1B2
A2B1 A2B2

a*b.t()/aa/bb相当于是
A1B1/|A1||B1| A1B2/|A1||B1|
A2B1/|A2||B2| A2B2/|A2||B2|

|A1|表示向量2范数。就是求aa和bb的过程。那这时就能知道为什么错了。

那应该怎样才能实现cos这个功能呢?实际上就是怎么得到除法中的被除数。仍然是矩阵乘法。

e = ab / (aa.matmul(bb.t()))
print('改进后的余弦相似度计算\n', e)

代入验证。可以发现实现了余弦相似度计算功能。

写这个的目的就是,每次遇到了都要花上半天才能搞明白。

cosine_similarity的结果,每个位置是算[1,h]和[1,h]向量的余弦相似度。这个计算就是ab/|a||b|。而|a|=(a1*a1+a2*a2+...)**0.5。

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值