Pytorch计算余弦相似度距离——torch.nn.CosineSimilarity函数中的dim参数使用方法

PyTorch中CosineSimilarity计算高维特征相似度

前言


前言

现在要使用Pytorch中自带的torch.nn.CosineSimilarity函数计算两个高维特征图(B,C,H,W)中各个像素位置的特征相似度,即特征图中的每个像素位置上的一个(B,C,1,1)的向量为该位置的特征,总共有BxHxW个特征。

一、官方函数用法

        意思是 dim参数指定了函数在哪个维度上进行余弦距离计算,计算之后该维度会消失,而其他维度的形状保持不变。但是现有的大多数博客将dim的用法复杂化,因此这里进行简单的实验验证,来验证一下上述说法。

二、实验验证

1.计算高维数组中各个像素位置的余弦距离

创造高维数组,在通道维度(即dim=1)上进行向量的余弦距离计算,并查看其中第一批数据中的位置(0,0)

torch.nn.CosineSimilarityPyTorch中用于计算余弦相似度函数余弦相似度用向量空间中两个向量夹角的余弦值作为衡量两个个体间差异的大小,余弦值越接近1,表明夹角越接近0度,即两个向量越相似,该函数利用的就是这一理论思想,通过计算两个向量夹角的余弦值来衡量向量之间的相似度值[^4]。 以下给出使用示例: ```python import torch import torch.nn.functional as F import math # 定义输入张量 a = torch.tensor([[1, 2], [3, 4]], dtype=torch.float) b = torch.tensor([[5, 6], [7, 8]], dtype=torch.float) # 自定义函数计算余弦相似度 def check(vec_a, vec_b): dot = 0 for i in range(len(vec_a)): dot += vec_a[i]*vec_b[i] vec_a_sq_sum = math.sqrt(sum([item*item for item in vec_a])) vec_b_sq_sum = math.sqrt(sum([item*item for item in vec_b])) return dot/(vec_a_sq_sum*vec_b_sq_sum) # 使用dim=0计算余弦相似度 res_0 = F.cosine_similarity(a, b, dim=0) check1_0 = check([1,3], [5,7]) check2_0 = check([2,4], [6,8]) print("dim=0时F.cosine_similarity计算结果:", res_0) print("dim=0时自定义函数计算结果:", check1_0, check2_0) # 使用dim=1计算余弦相似度 res_1 = F.cosine_similarity(a, b, dim=1) check1_1 = check([1,2], [5,6]) check2_1 = check([3,4], [7,8]) print("dim=1时F.cosine_similarity计算结果:", res_1) print("dim=1时自定义函数计算结果:", check1_1, check2_1) ``` 对于二维矩阵,dim=0表示对应列的列向量之间进行cos相似度计算dim=1表示对应行的行向量之间进行cos相似度计算[^3]。
评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值