格式
torch.max(input, dim, keepdim=False) → output tensors (max, max_indices)
举例
dim=0
a=torch.rand(2,3,4)
a
tensor([[[0.2912, 0.8998, 0.0251, 0.8944],
[0.9834, 0.8184, 0.1100, 0.5150],
[0.2375, 0.0448, 0.1398, 0.4678]],
[[0.9378, 0.0389, 0.6927, 0.4468],
[0.6021, 0.3559, 0.9302, 0.1069],
[0.3244, 0.5402, 0.5018, 0.6516]]])
torch.max(a,dim=0)
torch.return_types.max(
values=tensor([[0.9378, 0.8998, 0.6927, 0.8944],
[0.9834, 0.8184, 0.9302, 0.5150],
[0.3244, 0.5402, 0.5018, 0.6516]]),
indices=tensor([[1, 0, 1, 0],
[0, 0, 1, 0],
[1, 1, 1, 1]]))
输入的tensor.shape为[2,3,4],指定比较的dim=0,shape[0]=2,返回shape为[3,4]的tensor和对应的索引。
可以看到,返回tensor的值是取每个位置对应的2个数间的最大值。
dim=1
a
tensor([[[0.2912, 0.8998, 0.0251, 0.8944],
[0.9834, 0.8184, 0.1100, 0.5150],
[0.2375, 0.0448, 0.1398, 0.4678]],
[[0.9378, 0.0389, 0.6927, 0.4468],
[0.6021, 0.3559, 0.9302, 0.1069],
[0.3244, 0.5402, 0.5018, 0.6516]]])
torch.max(a,dim=1)
torch.return_types.max(
values=tensor([[0.9834, 0.8998, 0.1398, 0.8944],
[0.9378, 0.5402, 0.9302, 0.6516]]),
indices=tensor([[1, 0, 2, 0],
[0, 2, 1, 2]]))
输入的tensor.shape为[2,3,4],指定比较的dim=1,shape[1]=3,返回shape为[2,4]的tensor和对应的索引。
可以看到,返回tensor的值是取每个位置对应的3个数间的最大值。
dim=2
a
tensor([[[0.2912, 0.8998, 0.0251, 0.8944],
[0.9834, 0.8184, 0.1100, 0.5150],
[0.2375, 0.0448, 0.1398, 0.4678]],
[[0.9378, 0.0389, 0.6927, 0.4468],
[0.6021, 0.3559, 0.9302, 0.1069],
[0.3244, 0.5402, 0.5018, 0.6516]]])
torch.max(a,dim=2)
torch.return_types.max(
values=tensor([[0.8998, 0.9834, 0.4678],
[0.9378, 0.9302, 0.6516]]),
indices=tensor([[1, 0, 3],
[0, 2, 3]]))
输入的tensor.shape为[2,3,4],指定比较的dim=2,shape[2]=4,返回shape为[2,3]的tensor和对应的索引。
可以看到,返回tensor的值是取每个位置对应的4个数间的最大值。