>>> import torch
>>> x = torch.arange(15).view(3,5)*1.0
>>> print(x)
tensor([[ 0.,  1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14.]])
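torch.mean(input, dim, keepdim=False): mean of `input` along dimension `dim` (dim=0 averages down each column, dim=1 averages across each row)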
>>> x_mean0 = torch.mean(x, dim=0, keepdim=True)
>>> print(x_mean0)
tensor([[5., 6., 7., 8., 9.]])
>>>
>>> x_mean1 = torch.mean(x, dim=1 ,keepdim=True)
>>> print(x_mean1)
tensor([[ 2.],
        [ 7.],
        [12.]])
>>>
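torch.max(input, dim, keepdim=False): returns a (values, indices) pair giving the maximum along dimension `dim` and the position where it occurs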
>>> values0, indices0 = torch.max(x, dim=0 ,keepdim=True)
>>> print(values0)
tensor([[10., 11., 12., 13., 14.]])
>>> print(indices0)
tensor([[2, 2, 2, 2, 2]])
>>>
>>> values1, indices1 = torch.max(x, dim=1 ,keepdim=True)
>>> print(values1)
tensor([[ 4.],
        [ 9.],
        [14.]])
>>> print(indices1)
tensor([[4],
        [4],
        [4]])
>>>
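keepdim的作用 (what keepdim does): with keepdim=True the reduced dimension is kept with size 1, so the first result below has shape (1, 5); with keepdim=False that dimension is dropped and the result has shape (5,).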
>>> x_mean = torch.mean(x, dim=0, keepdim=True)
>>> print(x_mean)
tensor([[5., 6., 7., 8., 9.]])
>>> x_mean = torch.mean(x, dim=0,keepdim=False)
>>> print(x_mean)
tensor([5., 6., 7., 8., 9.])
>>>