import torch
a= torch.ones((2,5,4))
a.shape
a.sum(axis=1,).shape
输出结果
torch.Size([2, 4])
思考下,若是axis为0时,输出结果应该是torch.Size([5, 4])
axis为2,torch.Size([2, 5])
import torch
a= torch.ones((2,5,4))
a.shape
a.sum(axis=1,).shape
输出结果
torch.Size([2, 4])
思考下,若是axis为0时,输出结果应该是torch.Size([5, 4])
axis为2,torch.Size([2, 5])