x=torch.Tensor(range(1,17))
x=x.reshape(2,2,2,2)
print('x: ',x)
print('mean(0): ',x.mean(0))
print('mean(1): ',x.mean(1))
print('mean(2): ',x.mean(2))
print('mean(3): ',x.mean(3))
out:
x:
tensor([[[[ 1., 2.],
[ 3., 4.]],
[[ 5., 6.],
[ 7., 8.]]],
[[[ 9., 10.],
[11., 12.]],
[[13., 14.],
[15., 16.]]]])
mean(0):
tensor([[[ 5., 6.],
[ 7., 8.]],
[[ 9., 10.],
[11., 12.]]])
mean(1):
tensor([[[ 3., 4.],
[ 5., 6.]],
[[11., 12.],
[13., 14.]]])
mean(2):
tensor([[[ 2., 3.],
[ 6., 7.]],
[[10., 11.],
[14., 15.]]])
mean(3):
tensor([[[ 1.5000, 3.5000],
[ 5.5000, 7.5000]],
[[ 9.5000, 11.5000],
[13.5000, 15.5000]]])
mean函数中的参数dim代表在第几维度求平均数。
dim=0时,在第一个维度,也就是(1+9)/2=5,(2+10)/2=6,.......
dim=1时,在第二个维度,也就是(1+5)/2=3,(2+6)/2=4,........
dim=2时,在第三个维度,也就是(1+3)/2=2,(2+4)/2=3,....
dim=3时,在第四个维度,也就是(1+2)/2=1.5,(3+4)/2=3.5,.......