一、方法详解
含义:顾名思义,返回一个Tensor的均值
torch.mean(input, dim, keepdim=False)
这个用言语很难说明白,我们直接通过案例来理解掌握!
二、案例
import torch
x = torch.randn(4, 4)
print(x.size())
y = torch.mean(x, dim=0, keepdim=True)
print(y.size())
如果我们要沿dim=0这个维度相同,而且keepdim=True,那么输出张量的另一个维度dim=1就跟输入张量相同。
import torch
x = torch.randn(4, 4)
print(x.size())
y = torch.mean(x, dim=1, keepdim=True)
print(y.size())