torch.max() 函数可以在指定维度上返回张量的最大值。它可以接受多个张量作为输入,返回一个张量,该张量包含所有输入张量中对应位置的最大值。
设置dim参数时,则返回每行的最大值及对应的索引位置(返回值也均为张量)。
官方文档:torch.max — PyTorch 2.0 documentation
语法:**torch.max(input, dim, keepdim=False, *, out=None)
参数:
• input (Tensor) – the input tensor.
• dim (int) – the dimension to reduce.
• keepdim (bool) – whether the output tensor has dim
retained or not. Default: False
.
import torch
# 创建一个 3x4 的张量
x = torch.randn(3, 4)
print(x)
# 返回整个张量的最大值
max_val = torch.max(x)
print(max_val)
# 返回每行的最大值及对应的索引位置(均为张量)
max_val, max_index = torch.max(x, dim=1)
print(max_val)
print(max_index)
# 返回每列的最大值
max_val, _ = torch.max(x, dim=0)
print(max_val)