torch.max()函数原型:
torch.max(input, dim, keepdim=False, *, out=None)
input:输入的Tensor
dim:将要在哪一个维度上进行比较
keepdim:是否保留维度信息,默认False。为True时,返回值将增加一个与dim相同的维度
当只传入input时,则在整个input中寻找最大值,只返回最大值,不返回索引。
input = torch.tensor([[6, 2, 7], [8, 1, 3], [4, 5, 9]])
print(input)
values = torch.max(input)
print(values)
tensor([[6, 2, 7],
[8, 1, 3],
[4, 5, 9]])
tensor(9)
当传入dim时,将返回dim对应维度上的最大值与对应的dim维度的索引。若input为一个二维的Tensor,那么dim=0时将在列上寻找最大值,并返回对应行的索引;dim=1时将在行上寻找最大值,并返回对应列的索引。
# dim=0 在列上寻找最大值 返回对应行的索引
input = torch.tensor([[6, 2, 7], [8, 1, 3], [4, 5, 9]])
print(input)
values, indices = torch.max(input, dim=0)
print(values)
print(indices)
output:
tensor([[6, 2, 7],
[8, 1, 3],
[4, 5, 9]])
tensor([8, 5, 9])
tensor([1, 2, 2])
# dim=1 在行上寻找最大值 返回对应列的索引
input = torch.tensor([[6, 2, 7], [8, 1, 3], [4, 5, 9]])
print(input)
values, indices = torch.max(input, dim=1)
print(values)
print(indices)
output:
tensor([[6, 2, 7],
[8, 1, 3],
[4, 5, 9]])
tensor([7, 8, 9])
tensor([2, 0, 2])
当传入keepdim时,则会保留维度的信息。以input为二维的Tensor举例。若dim=0,keepdim=True,将返回每列上的最大值,与对应行的索引,值与索引的格式保留了列的维度信息。可以看到当keepdim=False时,返回的值与索引的shape为3;keepdim=True时,返回的值与索引的shape为[1, 3],列的维度信息被保留下来。
# dim=0 keepdim=True 返回值保留列维度信息
input = torch.tensor([[6, 2, 7], [8, 1, 3], [4, 5, 9]])
print(input)
values, indices = torch.max(input, dim=0, keepdim=True)
print(values)
print(indices)
output:
tensor([[6, 2, 7],
[8, 1, 3],
[4, 5, 9]])
tensor([[8, 5, 9]])
tensor([[1, 2, 2]])
# dim=1 keepdim=True 返回值保留了行维度信息
input = torch.tensor([[6, 2, 7], [8, 1, 3], [4, 5, 9]])
print(input)
values, indices = torch.max(input, dim=1, keepdim=True)
print(values)
print(indices)
output:
tensor([[6, 2, 7],
[8, 1, 3],
[4, 5, 9]])
tensor([[7],
[8],
[9]])
tensor([[2],
[0],
[2]])