torch.mean求完平均直接返回张量
torch.max求完最大返回张量和最大值对应的索引
在使用时需要注意
举个例子:
t=torch.randn(3,4,2)
print(t)
tensor([[[-0.8378, 1.0686],
[ 1.0435, -0.5058],
[-0.2627, -0.0802],
[-0.1599, 0.5249]],
[[-0.8277, 1.3927],
[-1.0998, 1.1341],
[ 1.7856, -0.7895],
[ 1.7492, -0.9021]],
[[ 0.3713, 1.4117],
[ 0.8045, -0.7918],
[-1.3890, 0.4643],
[-1.4575, -0.3566]]])
torch.max(t,dim=1,keepdim=False)
torch.return_types.max(
values=tensor([[1.0435, 1.0686],
[1.7856, 1.3927],
[0.8045, 1.4117]]),
indices=tensor([[1, 0],
[2, 0],
[1, 0]]))
torch.mean(t,dim=1,keepdim=False)
tensor([[-0.0542, 0.2519],
[ 0.4018, 0.2088],
[-0.4177, 0.1819]])
所以根据需要选择torch.max(*)[0] 选择值
选择torch.max(*)[1] 选择索引