以shape为(2, 3, 3)的tensor为例,理解x.mean或torch.mean的用法
1. 创建tensor
import torch
x = torch.arange(18, dtype=torch.float32).view(2, 3, 3)
2. dim=0时
dim_0 = x.mean(dim=0)
print(dim_0)
输出结果:
可以看出,dim=0时,在x的通道维度进行了求平均,以dim_0[0][0]为例,dim_0[0][0] = (x[0][0][0]+x[1][0][0])/2。
3. dim=1时
dim_1 = x.mean(dim=1)
print(dim_1)
输出结果:
可以看出,当dim=1时,在x的行方向(或者说高H方向)进行求平均,以dim_1[0][0]为例,dim_1[0][0] = (x[0][0][0]+x[0][1][0]+x[0][2][0])/3。
4. dim=2时
dim_2 = x.mean(dim=2)
print(dim_2)
可以看出,当dim=2时,在x的列方向(或者说宽W方向)进行求平均。
5. dim=[1, 2]
dim_1_2 = x.mean(dim=[1, 2])
print(dim_1_2)
输出结果:
在行方向求平均的结果上再沿列方向求平均
6. dim=[0, 1]
dim_0_1 = x.mean(dim=[0, 1])
print(dim_0_1)
输出结果:
![dim=0, 1
在沿通道方向求平均的结果上再沿行方向求平均
另一个参数keepdim,默认值为False
比较keepdim=False与keepdim=True时的输出结果:
可以看出,keepdim=True时维度保持不变
x.mean(dim)与torch.mean(x,dim)
x.mean(dim)与torch.mean(x,dim)能够实现相同的效果。
以上均为个人理解,如有错误或不妥之处,欢迎大家批评指正!!