argmax
函数用于找到数组中最大元素的索引。
行或列的最大值:
import numpy as np
arr_2d = np.array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# 使用 argmax 找到每列的最大值的索引(沿着行的方向,axis=0)
max_indices_per_column = np.argmax(arr_2d, axis=0)
print("每列的最大值的索引:", max_indices_per_column)
# 输出: [2 2 2],因为每列的最大值都在第三行。
# 使用 argmax 找到每行的最大值的索引(沿着列的方向,axis=1)
max_indices_per_row = np.argmax(arr_2d, axis=1)
print("每行的最大值的索引:", max_indices_per_row)
# 输出: [2 2 2],因为每行的最大值都在第三列。
import torch
# 创建一个二维 PyTorch 张量
tensor_2d = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# 使用 argmax 找到每列的最大值的索引(沿着行的方向,dim=0)
max_indices_per_column = torch.argmax(tensor_2d, dim=0)
print("每列的最大值的索引:", max_indices_per_column)
# 输出: tensor([2, 2, 2]),因为每列的最大值都在第三行。
# 使用 argmax 找到每行的最大值的索引(沿着列的方向,dim=1)
max_indices_per_row = torch.argmax(tensor_2d, dim=1)
print("每行的最大值的索引:", max_indices_per_row)
# 输出: tensor([2, 2, 2]),因为每行的最大值都在第三列。
整个矩阵的最大值:
import numpy as np
# 创建一个二维 NumPy 数组
arr_2d = np.array([[1, 2, 3],
[4, 5, 9],
[7, 8, 6]])
# 使用 argmax 找到整个矩阵中最大值的索引
max_index = np.argmax(arr_2d)
# 由于这是一个二维数组,我们需要将一维索引转换为二维坐标
row = max_index // arr_2d.shape[1]
col = max_index % arr_2d.shape[1]
print("最大值的索引(行,列):", (row, col))
# 最大值的索引(行,列): (1, 2)
import torch
# 创建一个二维 PyTorch 张量
tensor_2d = torch.tensor([[1, 2, 3],
[4, 5, 9],
[7, 8, 6]])
# 将二维张量展平为一维张量
tensor_flat = tensor_2d.view(-1)
# 使用 argmax 找到整个张量中最大值的索引
max_index = torch.argmax(tensor_flat)
# 如果需要,可以将一维索引转换为原始的二维坐标
row = max_index // tensor_2d.size(1)
col = max_index % tensor_2d.size(1)
print("最大值的索引(行,列):", (row, col))
# 最大值的索引(行,列): (tensor(1), tensor(2))