torch.
mean
(input, dim, keepdim=False, *, out=None) → Tensor
第二个参数dim表示对哪个维度求均值
a = torch.randn(2, 3)
b = a.mean(1)
c = a.mean(-1)
print(a)
print(b)
print(c)
tensor([[-1.2267, 0.4601, -1.7988],
[ 0.2121, -1.5083, -0.2360]])
tensor([-0.8551, -0.5108])
tensor([-0.8551, -0.5108])
通过上述代码的结果,可以看出tensor.mean(1)和tensor.mean(-1)都是对一行的元素求均值