在处理图像时,会用到这个函数。记录一下我自己跌进去的输出大坑。
情况1
输入1:
import torch
a = torch.tensor([[
[
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]
],
[
[0, 0, 0, 0],
[0, 1, 1, 0],
[0, 0, 0, 0]
]]])
b = torch.argmax(a, dim=1)
print(a.shape)
print(b)
输出1:
torch.Size([1, 2, 3, 4])
tensor([[[0, 0, 0, 0],
[0, 1, 1, 0],
[0, 0, 0, 0]]]);
情况2
输入2:
import torch
a = torch.tensor([[
[
[0, 0, 0, 0],
[0, 1, 1, 0],
[0, 0, 0, 0]
],
[
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]
]]])
b = torch.argmax(a, dim=1)
print(a.shape)
print(b)
输出2:
torch.Size([1, 2, 3, 4])
tensor([[[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]]])
情况3
输入3:
import torch
a = torch.tensor([[
[
[0, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 0, 0]
],
[
[0, 0, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 0]
]]])
b = torch.argmax(a, dim=1)
print(a.shape)
print(b)
输出3:
torch.Size([1, 2, 3, 4])
tensor([[[0, 0, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 0]]])
情况4
输入4:
import torch
a = torch.tensor([[
[
[0, 0, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 0]
],
[
[0, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 0, 0]
]]])
b = torch.argmax(a, dim=1)
print(a.shape)
print(b)
输出4:
torch.Size([1, 2, 3, 4])
tensor([[[0, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 0, 0]]])
torch.argmax()的编码是依据所要排序的dim的第一维来确定编号的。