1.区别
keepdim=True
运算完之后的维度和原来一样,原来是三维数组现在还是三维数组(不过某一维度变成了1);
keepdim=False
运算完之后一般少一维度,求平均变为1的那一维没有了;
axis=k
按第k维运算,其他维度不遍,第k维变为1。
import numpy as np
x=[
[[1,2,3,4],[5,6,7,8],[9,10,11,12]],
[[13,14,15,16],[17,18,19,20],[21,22,23,24]]
]
x = np.array(x)
x = torch.from_numpy(x)
x
Out[64]:
tensor([[[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]],
[[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]]], dtype=torch.int32)
x = x.type(torch.float64)
x
Out[66]:
tensor([[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]],
[[13., 14., 15., 16.],
[17., 18., 19., 20.],
[21., 22., 23., 24.]]], dtype=torch.float64)
print("shape of x.mean(axis=0,keepdim=True):")
shape of x.mean(axis=0,keepdim=True):
x.shape
Out[68]: torch.Size([2, 3, 4])
print(x.mean(axis=0,keepdim=True).shape)
torch.Size([1, 3, 4])
print("shape of x.mean(axis=0,keepdim=False):") #[3, 4]
print(x.mean(axis=0,keepdim=False).shape)
shape of x.mean(axis=0,keepdim=False):
torch.Size([3, 4])