torch.max()
-
torch.max(input: Tensor) --> Tensor
: 返回所有元素的最大值 -
torch.max(input: Tensor, dim: int, keepdim: bool = False) --> Tuple(Tensor, Tensor)
: 返回指定维度的最大值和对应的索引- 如果
keepdim=True
,维度不变,指定维度的大小为1
- 如果
-
torch.Tensor.max()
参考torch.max()
-
代码示例
import torch x = torch.randn(3, 4) max_value = x.max() max_values, max_indices = x.max(dim=0)
torch.argmax()
-
torch.argmax(input: Tensor) --> Tensor
:返回所有元素最大值的索引 -
torch.argmax(input: Tensor, dim: int, keepdim: bool = False)
: 返回指定维度的最大值对应的索引 -
torch.Tensor.argmax()
参考torch.argmax()
-
代码示例
x = torch.randn(3, 4) max_index = x.argmax() max_indices = x.argmax(dim=0)
torch.maximum()
torch.maximum(input: Tensor, other: Tensor) --> Tensor
: 返回input tensor
和other tensor
的逐元素最大值other
的维度可以与input
的维度不同,但维度较少的tensor
的维度大小要与维度较多的tensor
的最后维度大小保持一致torch.Tensor.maximum()
参考torch.maximum()
- 代码示例
# 维度和维度大小相同
x1 = torch.randn(2, 3, 4)
y1 = torch.randn(2, 3, 4)
# 维度不同,维度较少的tensor的维度大小与维度较多的tensor的最后维度大小保持一致
x2 = torch.randn(1, 2, 3, 4)
y2 = torch.randn(2, 3, 4)
z1 = torch.maximum(x1, y1)
z2 = torch.maximum(x2, y2)
print(z1.shape, z2.shape)
# torch.Size([2, 3, 4]) torch.Size([1, 2, 3, 4])