.max(dim=0)
是 PyTorch 中的张量操作,用于在指定维度上张量的最大值。
对于一个张量,.max(dim=0)
将返回两个张量:
- 第一个张量包含沿着指定维度(这里是维度 0,垂直方向上,每列)的最大值。
- 第二个张量包含每列最大值所在的索引(位置)。
举例来说,考虑一个二维张量 tensor
:
tensor = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
使用 .max(dim=0)
:
max_values, max_indices = tensor.max(dim=0)
max_values
为 tensor([7, 8, 9])
,即每列的最大值;
max_indices
为 tensor([2, 2, 2])
,即最大值所在的索引。