对于Pytorch中dim=1的理解

本文详细介绍了Python中Tensor的shape和dim概念。shape表示数组的维度结构,例如(1,1,3,3)表示一个三维数组。dim用于指定在哪个维度上执行操作,如torch.mean()函数中,dim=2是对二维数组按列求平均,dim=3是对一维数组求平均。在模型预测中,理解dim有助于解析输出结果,例如当模型输出为(1,10)时,dim=1可以找到一维数组内的最大值。
摘要由CSDN通过智能技术生成

目  录

1 理解shape

2 理解dim

3 理解模型预测中的dim


1 理解shape

对于python中shape的理解:

(1,2) 表示1个一维数组,每个一维数组长度为2;

(1,2,3) 表示1个二维数组,每个二维数组有2个一维数组,每个一维数组长度为3;

(1,2,3,4) 表示1个三维数组,每个三维数组有2个二维数组,每个二维数组有3个一维数组,每个一维数组长度为4。


以下面的tensor为例:

import torch
a = torch.tensor([[[[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]]]])
print(a.shape)

输出结果为torch.Size([1, 1, 3, 3]),表示1个三维数组,每个三维数组中有1个二维数组,每个二维数组中有3个一维数组,一维数组分别为[1, 2, 3]、[4, 5, 6]、[7, 8, 9],每个一维数组的长度为3。如果我们对tensor进行索引

print(a[0][0][0])
print(a[0][0][1])
print(a[0][0][2])

结果分别为:

tensor([1, 2, 3])
tensor([4, 5, 6])
tensor([7, 8, 9])

要注意,只有一个三维数组,所以第一个索引值只能为0,否则就会报错超出索引值;只有一个二维数组,同理第二个索引值也只能为0;有三个一维数组,第三个索引值可以是0、1、2;每个一维数组长度为3,第四个索引值也可以是0、1、2。


2 理解dim

然后我们再通过torch.mean()函数来理解dim:

a = a.float()  # 先转换成float格式,否则会报错
print(torch.mean(a, dim=2))
print(torch.mean(a, dim=3))

运行结果如下:

tensor([[[4., 5., 6.]]])
tensor([[[2., 5., 8.]]])

通过对比可以看出,对于该数组,dim可取的值为0、1、2、3。dim=2意味着在二维数组上进行求平均值的操作,即对一个矩阵按列求平均值;dim=3意味着在一维数组内进行求平均值的操作,即对每个一维数组求平均值。


3 理解模型预测中的dim

再理解在模型预测中遇到的dim,模型输出的数组outputs为(1, 10),即一个一维数组,一维数组中有10个元素,对于该函数:

torch.max(outputs, dim=1)

dim=1即在outputs中的一维数组内取最大值。

  • 7
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值