目录
torch常用API
1.torch.max(x,dim=1)
1.1定义
-
input输入的是一个tensor
-
dim是max函数索引的维度0/1,0是每列的最大值,1是每行的最大值
-
返回的是两个值:一个是每一行最大值的tensor组,另一个是最大值所在的位置
max_col_value = torch.max(x,dim=0)[0] # 每一列最大值
max_row_value = torch.max(x,dim=1)[0] # 每一行最大值
1.2example
x:
tensor([[0.5285, 0.1247, 0.8332, 0.5485],
[0.7917, 0.6138, 0.5881, 0.3381],
[0.4226, 0.6605, 0.8571, 0.0399],
[0.1716, 0.0609, 0.9712, 0.4838]])
torch.max(x,1):
(tensor([0.8332, 0.7917, 0.8571, 0.9712]), tensor([2, 0, 2, 2]))
torch.max(x,0):
(tensor([0.7917, 0.6605, 0.9712, 0.5485]), tensor([1, 2, 3, 0]))
torch.max(x,1)[0]:
tensor([0.8332, 0.7917, 0.8571, 0.9712])
torch.max(x,1)[1]:
tensor([2, 0, 2, 2])
torch.max(x,1)[1].data:
tensor([2, 0, 2, 2])
torch.max(x,1)[1].data.numpy():
[2 0 2 2]
torch.max(x,1)[1].data.numpy().squeeze():
[2 0 2 2]
torch.max(x,1)[0].data:
tensor([0.8332, 0.7917, 0.8571, 0.9712])
torch.max(x,1)[0].data.numpy():
[0.83318216 0.7917127 0.85708565 0.9711726 ]
torch.max(x,1)[0].data.numpy().squeeze():
[0.83318216 0.7917127 0.85708565 0.9711726 ]