Pytorch计算余弦相似度距离——torch.nn.CosineSimilarity函数中的dim参数使用方法

前言


前言

现在要使用Pytorch中自带的torch.nn.CosineSimilarity函数计算两个高维特征图(B,C,H,W)中各个像素位置的特征相似度,即特征图中的每个像素位置上的一个(B,C,1,1)的向量为该位置的特征,总共有BxHxW个特征。

一、官方函数用法

        意思是 dim参数指定了函数在哪个维度上进行余弦距离计算,计算之后该维度会消失,而其他维度的形状保持不变。但是现有的大多数博客将dim的用法复杂化,因此这里进行简单的实验验证,来验证一下上述说法。

二、实验验证

1.计算高维数组中各个像素位置的余弦距离

创造高维数组,在通道维度(即dim=1)上进行向量的余弦距离计算,并查看其中第一批数据中的位置(0,0)上的两个向量之间的余弦距离:

>>> import torch
>>> import torch.nn as nn

>>> cos = nn.CosineSimilarity(dim=1, eps=1e-6)

>>> input1 = torch.randn(3, 64, 100, 128)
>>> input2 = torch.randn(3, 64, 100, 128)

>>> output = cos(input1, input2)

>>> output[0, 0, 0]
tensor(-0.1095)

2.验证高维数组中任意一个像素位置的余弦距离

将上述高维数组中的第一批数据中的位置(0,0)上的各个通道数值组成该位置上的特征向量,并计算两个向量间的余弦距离:

>>> import torch
>>> import torch.nn as nn

>>> cos2 = nn.CosineSimilarity(dim=0, eps=1e-6)

>>> input3=input1[0, :, 0, 0]
>>> input4=input2[0, :, 0, 0]

>>> output2 = cos2(input3, input4)

>>> output2
tensor(-0.1095)

发现两个距离是相同的,因此dim参数指定了函数在哪个维度上进行余弦距离计算,计算之后该维度会消失,而其他维度的形状保持不变。


总结

  Pytorch中自带的torch.nn.CosineSimilarity函数计算两个高维特征图中各个像素位置的特征相似度,其中dim参数指定了函数在哪个维度上进行余弦距离计算,计算之后该维度会消失,而其他维度的形状保持不变。

torch.nn.Sequential()函数PyTorch的一个类,可以用于构建神经网络模型。它允许我们按照顺序将多个层(layers)组合在一起,构建一个神经网络模型。 使用torch.nn.Sequential()函数时,我们可以将多个层作为参数传递给它,按照传递的顺序依次添加到模型。每个层都可以是PyTorch提供的预定义层(如全连接层、卷积层等),也可以是自定义的层。 下面是一个简单的示例,展示了如何使用torch.nn.Sequential()函数构建一个简单的前馈神经网络模型: ```python import torch import torch.nn as nn # 定义一个简单的前馈神经网络模型 model = nn.Sequential( nn.Linear(784, 256), # 全连接层1:输入大小为784,输出大小为256 nn.ReLU(), # ReLU激活函数 nn.Linear(256, 128), # 全连接层2:输入大小为256,输出大小为128 nn.ReLU(), # ReLU激活函数 nn.Linear(128, 10), # 全连接层3:输入大小为128,输出大小为10(输出类别数) nn.Softmax(dim=1) # Softmax激活函数,用于多分类问题 ) # 打印模型结构 print(model) ``` 在这个例子,我们使用torch.nn.Sequential()函数创建了一个模型,并按照顺序添加了三个全连接层以及两个激活函数。最后一个全连接层的输出大小设置为10,因为我们假设这是一个10类分类问题。最后,我们使用print语句打印了模型的结构。 这样,通过torch.nn.Sequential()函数,我们可以方便地按照顺序组合多个层,构建神经网络模型。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值