# 目的是求a、b之间的余弦相似度。一般是[bs,h]形状。以[2,4]为例。
a = torch.rand(2, 4)
b = torch.rand(2, 4)
print('初始输入\n', a, '\n', b)
ab = torch.matmul(a, b.t())
print('ab相乘\n', ab)
aa = a*a
print('a*a\n',aa)
aa = aa.sum(1, keepdim=True)
print('求和aa')
aa = aa ** 0.5
print('开方aa\n', aa)
bb = b * b
print('b*b\n',bb)
bb = bb.sum(1, keepdim=True)
print('求和bb')
bb = bb ** 0.5
print('开方bb\n', bb)
c = ab/aa/bb
print('自己算的cosi相似度\n', c)
d = torch.nn.functional.cosine_similarity(a.unsqueeze(1), b.unsqueeze(0),dim=2)
print('固有算法求cos相似度\n', d)
这段代码实际上没有实现。你会发现只有对角线的位置计算正确,而非对角线位置错误。
a:[2,4]
按行分块 结果可写为:
A1
A2
b:[2,4]按行分块
B1
B2
a*b.t()相当于是
A1B1 A1B2
A2B1 A2B2
a*b.t()/aa/bb相当于是
A1B1/|A1||B1| A1B2/|A1||B1|
A2B1/|A2||B2| A2B2/|A2||B2|
|A1|表示向量2范数。就是求aa和bb的过程。那这时就能知道为什么错了。
那应该怎样才能实现cos这个功能呢?实际上就是怎么得到除法中的被除数。仍然是矩阵乘法。
e = ab / (aa.matmul(bb.t()))
print('改进后的余弦相似度计算\n', e)
代入验证。可以发现实现了余弦相似度计算功能。
写这个的目的就是,每次遇到了都要花上半天才能搞明白。
cosine_similarity的结果,每个位置是算[1,h]和[1,h]向量的余弦相似度。这个计算就是ab/|a||b|。而|a|=(a1*a1+a2*a2+...)**0.5。