import torch
a =torch.ones((2,5,4))
a.shape
torch.Size([2, 5, 4])
a.sum(axis=0)
tensor([[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.]])
a.sum(axis=1)
tensor([[5., 5., 5., 5.],
[5., 5., 5., 5.]])
a.sum(axis=2)
tensor([[4., 4., 4., 4., 4.],
[4., 4., 4., 4., 4.]])
a.sum(axis=[0,1])
tensor([10., 10., 10., 10.])
先 axis=0 后 axis=1
a.sum(axis=[0,2])
tensor([8., 8., 8., 8., 8.])
先 axis=0 后 axis=2
a.sum(axis=1,keepdims=True)
tensor([[[5., 5., 5., 5.]],
[[5., 5., 5., 5.]]])
keepdims=True 保留维度为1、