官网:https://pytorch.org/docs/stable/torch.html#torch.max
torch.max()和torch.min()是比较tensor大小的函数。两者用法相同,所以就总结了一个。
(1)不指定比较维度:torch.max(input)
x = torch.rand(1,3)
print(x)
print(torch.min(x))
y = torch.rand(2,3)
print(y)
print(torch.min(y))
output:
tensor([[0.4094, 0.0262, 0.9132]])
tensor(0.0262)
tensor([[0.4712, 0.3108, 0.3703],
[0.0609, 0.8676, 0.7341]])
tensor(0.0609)
(2)指定比较维度:torch.max(input,dim)
output 返回tuple:tuple[0] -> 比较结果; tuple[1] ->所在索引
y = torch.rand(2,3)
print(y)
print(torch.max(y,0))
output:
tensor([[0.7573, 0.4121, 0.0922],
[0.0562, 0.1346, 0.5164]])
(tensor([0.7573, 0.4121, 0.5164]), tensor([0, 0, 1]))
(3)两个tensor相比较:不一定是相同大小结构,若不是相同大小结构,必须满足可广播
相同结构:比较相同位置的返回结果
x = torch.rand(2,3)
y = torch.rand(2,3)
print(x)
print(y)
print(torch.max(x,y))
output:
tensor([[0.9054, 0.4904, 0.4252],
[0.5209, 0.8509, 0.7347]])
tensor([[0.2347, 0.4457, 0.4466],
[0.5157, 0.5463, 0.0814]])
tensor([[0.9054, 0.4904, 0.4466],
[0.5209, 0.8509, 0.7347]])
不是相同结构的,按照广播原理将维度少的那个做一个数据复制再比较。
x = torch.rand(1,3)
y = torch.rand(2,3)
print(x)
print(y)
print(torch.max(x,y))
output:
tensor([[0.2240, 0.1759, 0.3040]])
tensor([[0.6603, 0.1693, 0.5366],
[0.4192, 0.4316, 0.0386]])
tensor([[0.6603, 0.1759, 0.5366],
[0.4192, 0.4316, 0.3040]])