本篇从两个例子argmax和sum两个函数来进行理解,是我的一个暂时性理解,不知道对不对,正确性有待验证。
axis可以理解为从哪个维度来计算:比如axis=0就表示从第零个维度开始计算。 (这里维度从0开始比较好理解)
如b[i][j]这里有2个维度,axis=0表示在i方向上进行变化,即b[0][j] b[1][j]。
输出的形状应该和b[0][j] 和b[1][j]中的任何一个一样,即1*3的一个张量。
二维:
b=torch.tensor([[1,2,3],[3,5,4]])
print(b.argmax(axis=0))
print(b.sum(axis=0))
首先b是一个二维张量,argmax返回的是最大值对应的下标索引(索引从零开始),sum返回求和结果。
argmax函数:对应的就是求每一列中的最大值的索引。
sum函数:对应的就是求每一列的元素的和。
这两者的答案都是按照b[i][j]中第0个维度(即维度i)的变化来进行计算的。
因为第0个维度在二维张量b中表示的是行,所以就是按照行来计算
所以当axis为1时就表示按第1个维度来进行计算,即按列:
b=torch.tensor([[1,2,3],[3,5,4]])
print(b.argmax(axis=1))
print(b.sum(axis=1))
三维:
b=torch.tensor(
[
[
[1,2,13,4],
[14,5,6,7],
[10,9,7,8]
],
[
[4,3,12,1],
[7,12,4,13],
[13,4,7,15]
]
]
)
print(b.argmax(axis=0))
print(b.sum(axis=0))
对于b[i][j][k]中axis=0时,表示按照第一个维度进行计算,最后计算出来的形状应该是[j][k]的形状。因为b是一个2*3*4张量,所以计算出来答案是一个3*4的形状。
如果把这个2*3*4看做一个立方体,那么就是把第二层的3*4张量和第二层的3*4张量来进行运算。
即 1和4 2和3 13和12 4和1
......
argmax:1和4最大是4,索引为1;2和3最大是3,索引为1;13和12最大是13 ,索引为0;4和1最大是4,索引为0......
sum:1+4=5;2+3=5;13+12=25;4+1=5......
当axis=1:
b=torch.tensor(
[
[
[1,2,13,4],
[14,5,6,7],
[10,9,7,8]
],
[
[4,3,12,1],
[7,12,4,13],
[13,4,7,15]
]
]
)
print(b.argmax(axis=1))
print(b.sum(axis=1))
axis=1就按照b[i][j][k]中的i来运算,输出形状为i*k即2*4。
参考: