最近在学沐神的动手学深度学习,在学线性代数篇时遇到了这个小问题。
总的来说,axis是轴的意思,1表示横轴,方向从左到右;0表示纵轴,方向从上到下。
接下来下看一下代码
A 先看二维数组
A = torch.arange(12).reshape(3,4)
print(A)
A_sum_axis0 = A.sum(axis=0)
A_sum_axis0, A_sum_axis0.shape
axis=0的情况下,sum函数就将第一个行维度3进行压缩。对3个行向量中的每个相对应的标量相加得到一个4维(行)向量。
0+4+8=12,1+5+9=15等等
同理,axis=1时将在列维度4进行压缩。在横向把各个元素相加得到一个3维(列)向量。
0+1+2+3=6等等
B 接着看三维矩阵
B = torch.arange(24).reshape(2,3,4)
print(B)
print("\n")
B_sum_axis0 = B.sum(axis=0)
B_sum_axis0, B_sum_axis0.shape
同样先看axis=0的情况:
B是个3维矩阵,shape为(2,3,4),第一个2表示有2个形状为(3,4)的二维矩阵。
所以axis=0呢,就选中了这两个二维矩阵。调用sum函数就会将这两个形状为(3,4)的二维矩阵按元素相加。
axis=1的情况呢, 跟二维矩阵A在axis=0类似,不过这里只是又多了一个矩阵而已。于是会分别在行维度3分别压缩这两个矩阵。
最后是axis=2,跟二维矩阵A在axis=1类似。在列维度4分别压缩这两个矩阵。
另外,如果指定了参数keepdims=true,就不会把那个维度去掉,而是保留为1。
总而言之呢,axis等于n,就相当于把第n维压缩(拍扁)。沐神还专门讲了哈哈