import torch
a = torch.randn((2,3,2))
a = a.int()
a = a.float() #为了简便起见,随机生成
print(a)
print(a.mean(dim=(0,2)))
print(a.mean(dim=(0,1)))
tensor([[[ 0., 0.],
[ 1., 0.],
[ 0., -1.]],
[[-1., 0.],
[-1., 1.],
[ 0., 1.]]])
tensor([-0.2500, 0.2500, 0.0000])
tensor([-0.1667, 0.1667])
先看第一个a.mean(dim=(0,2)),其意思是去除第0维,第2维,因为a的size为(2,3,2)
去除第0维,第2维,只剩下中间的size,那么结果的size为(3)。
数值的计算:
为了更直观的看,我们以这样的形式表示a
去除第0维后,size会变成(3,2),也就是这两个表格会变成一个表格,如下:
也就是把两个表格的数据对应位置相加再求平均
接下来去除第2维,size会变成(3),也就是这两列会变成一列,如下:
就是列对应位置相加再求各自平均
也就是输出结果里的tensor([-0.2500, 0.2500, 0.0000])
同样的a.mean(dim=(0,1)),就是将第二张图里的表格,合并成只有一行两列的表格。