官方介绍:torch.mean — PyTorch 1.11.0 documentation
torch.
mean
(input, dim, keepdim=False, *, dtype=None, out=None)
参数
input,要输入的张量
dim,要求均值的维度
keepdim,求完均值之后是否要保留该维度
dtype,数据格式,(输入整数会被识别为long报错)
1、当dim为空时,输出全部值的平均数
2、当dim为常数时,输出延该维度求完平均数之后的张量
这是官方实例
a = torch.randn(4, 4)
tensor([[-0.3841, 0.6320, 0.4254, -0.7384],
[-0.9644, 1.0131, -0.6549, -1.4279],
[-0.2951, -1.3350, -0.7694, 0.5600],
[ 1.0842, -0.9580, 0.3623, 0.2343]])
torch.mean(a, 1)
tensor([-0.0163, -0.5085, -0.4599, 0.1807])
torch.mean(a, 1, True)
tensor([[-0.0