一个batch_size=32模型训练二分类预测输出的结果
output=torch.tensor([[0.4386, 0.6951],
[0.8676, 0.8337],
[0.5899, 0.5169],
[0.4655, 0.7301],
[0.5195, 0.8004],
[0.0000, 0.0000],
[0.9471, 0.9369],
[0.5637, 0.8579],
[0.7036, 0.4013],
[0.0772, 0.0000],
[0.4833, 0.7533],
[0.5751, 0.8727],
[0.7379, 0.6507],
[0.5818, 0.5124],
[0.5079, 0.7853],
[0.3538, 0.5229],
[0.8047, 0.9168],
[0.5269, 0.8100],
[0.4793, 0.6636],
[0.5675, 0.5823],
[0.0289, 0.0000],
[0.5372, 0.7232],
[0.5348, 0.8203],
[0.5092, 0.5612],
[0.5882, 0.7212],
[0.5012, 0.7163],
[0.6365, 0.8210],
[0.0000, 0.0475],
[0.5286, 0.7455],
[0.5431, 0.8311],
[0.7555, 0.8213],
[0.6197, 0.8221]])
output.argmax(dim=-1)
每一行的最大值下标
tensor([1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1])
output.argmax(dim=0)
在列上最大值的下标
tensor([6, 6])
aaa=torch.tensor([[1.8742, 0.0000],
[2.8360, 0.0000],
[2.2612, 0.0000],
[2.0899, 0.0000],
[1.4432, 0.0000],
[2.1185, 0.0000],
[1.9332, 0.0000],
[1.9641, 0.0000],
[2.1104, 0.0000],
[1.8978, 0.0000],
[1.7399, 0.0000],
[0.7568, 0.0000],
[2.1608, 0.0000],
[2.4813, 0.0000],
[1.7124, 0.0000],
[1.7205, 0.0000],
[0.8863, 0.0000],
[0.5090, 0.0000],
[1.3449, 0.0000],
[0.4651, 0.0000],
[1.7381, 0.0000],
[0.4817, 0.0000],
[2.4699, 0.0000],
[1.9334, 0.0000],
[2.2564, 0.0000],
[0.4286, 0.0000],
[2.1439, 0.0000],
[0.2625, 0.0000],
[1.4068, 0.0000],
[2.1745, 0.0000],
[1.7621, 0.0000],
[1.5503, 100]])
aaa.argmax(dim=0)
在列上最大值下标
tensor([ 1, 31])
aaa.argmax(dim=-1)
在行上最大值下标
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1])
max
aaa=torch.tensor([[1.8742, 0.0000],
[2.8360, 0.0000],
[2.2612, 0.0000],
[2.0899, 0.0000],
[1.4432, 0.0000],
[2.1185, 0.0000],
[1.9332, 0.0000],
[1.9641, 0.0000],
[2.1104, 0.0000],
[1.8978, 0.0000],
[1.7399, 0.0000],
[0.7568, 0.0000],
[2.1608, 0.0000],
[2.4813, 0.0000],
[1.7124, 0.0000],
[1.7205, 0.0000],
[0.8863, 0.0000],
[0.5090, 0.0000],
[1.3449, 0.0000],
[0.4651, 0.0000],
[1.7381, 0.0000],
[0.4817, 0.0000],
[2.4699, 0.0000],
[1.9334, 0.0000],
[2.2564, 0.0000],
[0.4286, 0.0000],
[2.1439, 0.0000],
[0.2625, 0.0000],
[1.4068, 0.0000],
[2.1745, 0.0000],
[1.7621, 0.0000],
[1.5503, 100]])
aaa.max(dim=0)
取出在列上最大值及其下标
torch.return_types.max(
values=tensor([ 2.8360, 100.0000]),
indices=tensor([ 1, 31]))
aaa.max(dim=1)
torch.return_types.max(
values=tensor([ 1.8742, 2.8360, 2.2612, 2.0899, 1.4432, 2.1185, 1.9332,
1.9641, 2.1104, 1.8978, 1.7399, 0.7568, 2.1608, 2.4813,
1.7124, 1.7205, 0.8863, 0.5090, 1.3449, 0.4651, 1.7381,
0.4817, 2.4699, 1.9334, 2.2564, 0.4286, 2.1439, 0.2625,
1.4068, 2.1745, 1.7621, 100.0000]),
indices=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1]))