pytorch余弦相似度矩阵与角度矩阵

向量余弦相似度

 cosine_similarity ( x , y ) = cos ⁡ ⟨ x , y ⟩ = x T y ∥ x ∥ ∥ y ∥ \text{ cosine\_similarity}\left( \mathbf{x},\mathbf{y}\right)=\cos \left\langle \mathbf{x},\mathbf{y} \right\rangle=\frac{\mathbf{x}^T\mathbf{y}}{\|\mathbf{x}\|\|\mathbf{y}\|}  cosine_similarity(x,y)=cosx,y=xyxTy
其中 ⟨ x , y ⟩ \left\langle \mathbf{x},\mathbf{y} \right\rangle x,y表示 x \mathbf{x} x y \mathbf{y} y的角度

余弦相似度矩阵

接下来的问题是考虑矩阵 A = ( a 1 T a 2 T ⋮ a m T ) , a i ∈ R k , B = ( b 1 T b 2 T ⋮ b n T ) , b i ∈ R k \mathbf{A}=\begin{pmatrix} \mathbf{a}_1^T\\ \mathbf{a}_2^T\\ \vdots\\ \mathbf{a}_m^T \end{pmatrix},\mathbf{a}_i\in\mathbb{R}^k,\mathbf{B}=\begin{pmatrix} \mathbf{b}_1^T\\ \mathbf{b}_2^T\\ \vdots\\ \mathbf{b}_n^T \end{pmatrix},\mathbf{b}_i\in\mathbb{R}^k A=a1Ta2TamT,aiRk,B=b1Tb2TbnT,biRk
计算余弦相似度矩阵 C ∈ R m × n \mathbf{C}\in\mathbb{R}^{m\times n} CRm×n,
其中 c i j =  cosine_similarity ( a i , b j ) c_{ij}=\text{ cosine\_similarity}\left( \mathbf{a}_i,\mathbf{b}_j\right) cij= cosine_similarity(ai,bj)

第一种

先单位化,这样可以不用单位化
这里要注意防止分母为0,所以有一个eps
torch.einsum就是一种矩阵乘法

def cosine_similarity(a, b, eps=1e-7):
    """

    :param a: m*k
    :param b: n*k
    :return: m*n cosine similarity
    """
    a = a / torch.max(torch.norm(a, p=2, dim=-1, keepdim=True), torch.tensor(eps))
    b = b / torch.max(torch.norm(b, p=2, dim=-1, keepdim=True), torch.tensor(eps))
    return torch.einsum('ni,mi->nm', a, b)

第二种

其实就是最后一步直接矩阵乘法

def cosine_similarity(a, b, eps=1e-7):
    """

    :param a: m*k
    :param b: n*k
    :return: m*n cosine similarity
    """
    a = a / torch.max(torch.norm(a, p=2, dim=-1, keepdim=True), torch.tensor(eps))
    b = b / torch.max(torch.norm(b, p=2, dim=-1, keepdim=True), torch.tensor(eps))
    return torch.mm(a, b.T)

第三种(推荐)

from torch.nn.functional import cosine_similarity

cosine_similarity(a.unsqueeze(1), b.unsqueeze(0), dim=-1)

可以查看这里https://pytorch.org/docs/stable/generated/torch.nn.CosineSimilarity.html?highlight=cosinesimilarity#torch.nn.CosineSimilarity
在这里插入图片描述
A ∈ R m × k \mathbf{A}\in\mathbb{R}^{m\times k} ARm×k,经过.unsqueeze(1)会变成 m × 1 × k m\times 1\times k m×1×k的矩阵
B ∈ R n × k \mathbf{B}\in\mathbb{R}^{n\times k} BRn×k,经过.unsqueeze(0)会变成 1 × n × k 1\times n\times k 1×n×k的矩阵

dim=-1,这意味着最后会变成一个 m × n m\times n m×n的矩阵
因为形状不一样,所以会广播成 m × n × k m\times n\times k m×n×k的矩阵

举个例子
3 × 2 3\times 2 3×2的矩阵广播成 3 × 4 × 2 3\times 4\times 2 3×4×2
a经过unsqueeze(1)会变成 3 × 1 × 2 3\times 1\times 2 3×1×2的矩阵
最后的+zeros是为了触发广播
在这里插入图片描述
最后计算余弦相似度的时候
a[0,0]得到的是一个2个元素的向量
就是
a[0,0]和b[0,0]计算余弦相似度
a[0,0]和b[0,1]计算余弦相似度

a[0,1]和b[0,0]计算余弦相似度
a[0,1]和b[0,1]计算余弦相似度

a[m,n]和b[m,n]计算余弦相似度

角度矩阵

得到余弦相似度矩阵后计算 arccos ⁡ \arccos arccos,对应pytorch的torch.acos
需要注意的是,pytorch的acos有点问题https://github.com/pytorch/pytorch/issues/8069
如果数值接近-1或1,则会返回nan
解决方法是 ± e p s \pm eps ±eps

eps=1e-7
matrix_cos=matrix_cos.clamp(min=-1+eps,max=1+eps)

matrix_angle=torch.acos(matrix_cos)
  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Nightmare004

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值