Pytorch中tensor维度和torch.max()函数中dim参数的理解
维度
参考了 https://blog.csdn.net/qq_41375609/article/details/106078474 ,
对于torch中定义的张量,感觉上跟矩阵类似,不过常见的矩阵是二维的。当定义一个多维的张量时,比如使用 a =torch.randn(2, 3, 4) 创建一个三维的张量,返回的是一个
[
[
[-0.5166, 0.8298, 2.4580, -1.9504],
[ 0.1119, -0.3321, -1.3478, -1.9198],
[ 0.0522, -0.6053, 0.8119, -1.3469]
],
[
[-0.3774, 0.9283, 0.7996, -0.3882],
[-1.1077, 1.0664, 0.1263, -1.0631],
[-0.9061, 1.0081, -1.2769, 0.1035]
]
]
当使用 a.size() 返回维度结果时,结果为 torch.Size([2, 3, 4]),这里面有三个数值,数值的个数代表维度的个数 ,所以这里有三个维度,可以理解为一个中括号代表一个维度。数值 2 处在第一个位置,第一个位置代表是第一维度,2代表这个维度有两个元素,也就是第一个 [ ] 里面两个元素,3代表在第二个维度,也就是在第一个 [ ] 中的两个元素里面,又有三个元素,依次类推。这里格式十分固定,一旦定义,必须是一个元素里面有两个元素,这两个元素中每个再包含三个元素,再包含,依此类推,否则会报错。类似与树,维数等于相似的树的深度-1(以根为第一层),每一层就是一维。
如生成一个
torch.tensor(
[
[
[1, 2, 3, 4]
[3, 4, 2, 1]
[4, 1, 2, 3]
]
[
[2, 1, 3, 4]
[3, 4, 2, 1]
[4, 1, 2, 3]
]
]
)
方便理解,以下图的形式展示,这里竖线代表一个维度,竖线上所有节点代表同一维度的所有元素。在下面所有图中,同颜色的元素都是按照从上往下按顺序排列的。
一、dim参数
在使用torch.max()函数和其他的一些函数时,会有dim这个参数。官网中定义使用torch.max()函数时,生成的张量维度会比原来的维度减少一维,除非原来的张量只有一维了. 要减少消去的是哪一维便是由dim参数决定的,dim参数实际上指的是我们计算过程中所要消去的维度。因为在比较时必须要指定使用哪些数字来比较 ,或者进行其他计算,比如 max 使一些数据中只要大的,sum只取和的结果,自然就会删减其他的一些数据从而引起降维。
以上面生成的三维的张量为例子,有三个维度,但是维度的数字顺序是 dim = 0, 1, 2;
当指定torch.max(a,dim=0)时,也就是要删除第一个维度,删除第一个维度的话,那还剩下两个维度,也就是dim =1 ,2 。 剩下的两个维度的参数是 3 和 4,那么删除第一个维度后应该剩下torch.tensor(3, 4)这样形式的张量, dim参数可以使用负数,也就是负的索引,与列表中的索引相似,在本例中dim = -1 与dim = 2是一样的。
返回的
values=tensor([[